--- /dev/null
+graph() {
+ %0 : int = prim::Constant[value=1]()
+ %1 : int[] = prim::Constant[value=[0, -1]]()
+ %2 : int = prim::Constant[value=0]()
+ %3 : int = prim::Constant[value=6]()
+ %4 : int = prim::Constant[value=2]()
+ %5 : int = prim::Constant[value=3]()
+ %6 : int[] = prim::ListConstruct(%4, %5)
+ %a.1 : Tensor = aten::rand(%6, %3, %2, %1)
+ %8 : int[] = prim::ListConstruct(%4, %5)
+ %9 : Tensor = aten::rand(%8, %3, %2, %1)
+ %a : Tensor = aten::add_(%a.1, %9, %0)
+ return (%a);
+}
--- /dev/null
+graph() {
+ %0 : int = prim::Constant[value=1]()
+ %1 : int[] = prim::Constant[value=[0, -1]]()
+ %2 : int = prim::Constant[value=0]()
+ %3 : int = prim::Constant[value=6]()
+ %4 : int = prim::Constant[value=2]()
+ %5 : int = prim::Constant[value=3]()
+ %6 : int[] = prim::ListConstruct(%4, %5)
+ %a.1 : Tensor = aten::rand(%6, %3, %2, %1)
+ %8 : int[] = prim::ListConstruct(%4, %5)
+ %9 : Tensor = aten::rand(%8, %3, %2, %1)
+ %a.2 : Tensor = aten::add_(%a.1, %9, %0)
+ %11 : int[] = prim::ListConstruct(%4, %5)
+ %b.1 : Tensor = aten::rand(%11, %3, %2, %1)
+ %13 : int[] = prim::ListConstruct(%4, %5)
+ %14 : Tensor = aten::zeros(%13, %3, %2, %1)
+ %15 : Tensor = aten::gt(%a.2, %14)
+ %16 : bool = prim::TensorToBool(%15)
+ %b : Tensor = prim::If(%16)
+ block0() {
+ %18 : int[] = prim::ListConstruct(%4, %5)
+ %19 : Tensor = aten::rand(%18, %3, %2, %1)
+ %b.2 : Tensor = aten::add_(%b.1, %19, %0)
+ -> (%b.2)
+ }
+ block1() {
+ -> (%b.1)
+ }
+ return (%b);
+}
--- /dev/null
+graph(%a.1 : Tensor) {
+ %1 : int = prim::Constant[value=1]()
+ %2 : int[] = prim::Constant[value=[0, -1]]()
+ %3 : int = prim::Constant[value=0]()
+ %4 : int = prim::Constant[value=6]()
+ %5 : int = prim::Constant[value=2]()
+ %6 : int = prim::Constant[value=3]()
+ %7 : int[] = prim::ListConstruct(%5, %6)
+ %8 : Tensor = aten::rand(%7, %4, %3, %2)
+ %a : Tensor = aten::add_(%a.1, %8, %1)
+ return ();
+}
--- /dev/null
+graph(%a : Tensor) {
+ %1 : int = prim::Constant[value=1]()
+ %2 : int[] = prim::Constant[value=[0, -1]]()
+ %3 : int = prim::Constant[value=6]()
+ %4 : int = prim::Constant[value=0]()
+ %5 : int = prim::Constant[value=2]()
+ %6 : int = prim::Constant[value=3]()
+ %l : Tensor[] = prim::ListConstruct()
+ %8 : Tensor[] = aten::append(%l, %a)
+ %c.1 : Tensor = aten::select(%l, %4)
+ %10 : int[] = prim::ListConstruct(%5, %6)
+ %b : Tensor = aten::rand(%10, %3, %4, %2)
+ %12 : int[] = prim::ListConstruct(%5, %6)
+ %13 : Tensor = aten::rand(%12, %3, %4, %2)
+ %c : Tensor = aten::add_(%c.1, %13, %1)
+ return (%b);
+}
--- /dev/null
+graph(%a : Tensor) {
+ %1 : int[] = prim::Constant[value=[0, -1]]()
+ %2 : int = prim::Constant[value=6]()
+ %i.1 : int = prim::Constant[value=0]()
+ %4 : int = prim::Constant[value=2]()
+ %5 : int = prim::Constant[value=3]()
+ %6 : int = prim::Constant[value=9223372036854775807]()
+ %7 : int = prim::Constant[value=1]()
+ %l : Tensor[] = prim::ListConstruct()
+ %9 : Tensor[] = aten::append(%l, %a)
+ %10 : int[] = prim::ListConstruct(%4, %5)
+ %b : Tensor = aten::rand(%10, %2, %i.1, %1)
+ %12 : bool = aten::lt(%i.1, %7)
+ %i : int = prim::Loop(%6, %12, %i.1)
+ block0(%14 : int, %15 : int) {
+ %c.1 : Tensor = aten::select(%l, %i.1)
+ %17 : int[] = prim::ListConstruct(%4, %5)
+ %18 : Tensor = aten::rand(%17, %2, %i.1, %1)
+ %c : Tensor = aten::add_(%c.1, %18, %7)
+ %i.2 : int = aten::add(%15, %7)
+ %21 : bool = aten::lt(%i.2, %7)
+ -> (%21, %i.2)
+ }
+ return (%b);
+}
self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
+ def test_mutable_dce(self):
+ @torch.jit.script
+ def foo():
+ a = torch.rand(2, 3)
+ a += torch.rand(2, 3)
+ b = torch.rand(2, 3)
+ b += torch.rand(2, 3)
+ # b should be cleaned up but not a
+ return a
+
+ self.assertExpectedGraph(foo.graph)
+
+ def test_mutable_dce_block(self):
+ @torch.jit.script
+ def foo():
+ a = torch.rand(2, 3)
+ a += torch.rand(2, 3)
+ b = torch.rand(2, 3)
+ if bool(a > torch.zeros(2, 3)):
+ b += torch.rand(2, 3)
+ a += torch.rand(2, 3)
+ # a should be cleaned up but not b
+ return b
+
+ self.assertExpectedGraph(foo.graph)
+
+ def test_mutable_dce_graph_input(self):
+ @torch.jit.script
+ def foo(a):
+ a += torch.rand(2, 3)
+ # shouldn't clean up `a` even though it's not used in the output
+
+ self.assertExpectedGraph(foo.graph)
+
+ def test_mutable_dce_list(self):
+ @torch.jit.script
+ def foo(a):
+ l = []
+ l.append(a)
+ c = l[0]
+ b = torch.rand(2, 3)
+ c += torch.rand(2, 3)
+ return b
+
+ self.assertExpectedGraph(foo.graph)
+
+ def test_mutable_dce_loop(self):
+ @torch.jit.script
+ def foo(a):
+ l = []
+ l.append(a)
+ i = 0
+ b = torch.rand(2, 3)
+ while i < 1:
+ dead = torch.rand(2, 3)
+ c = l[0]
+ c += torch.rand(2, 3)
+ i += 1
+ return b
+
+ self.assertExpectedGraph(foo.graph)
+
class MnistNet(nn.Module):
def __init__(self):
void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
validateBlock(graph->block(), operator_export_type);
- EliminateDeadCode(graph);
+ EliminateDeadCode(graph->block());
}
class EncoderBase {
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
.def("_jit_pass_fuse", FuseGraph)
.def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
- return EliminateDeadCode(g); // overload resolution
+ return EliminateDeadCode(g->block()); // overload resolution
})
.def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
return EliminateCommonSubexpression(g); // overload resolution
return outputs_.at(i);
}
-bool Node::isBefore(const Node * n) const {
- if (this == n) {
- return false;
- }
- return !isAfter(n);
-}
+bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
+ if (this->owningBlock() == n->owningBlock()) {
+ if (moveSide == MoveSide::BEFORE) {
+ return this->topo_position_ < n->topo_position_;
+ }
-bool Node::isAfter(const Node * n) const {
- JIT_ASSERT(this->owningGraph() == n->owningGraph());
+ if (moveSide == MoveSide::AFTER) {
+ return this->topo_position_ > n->topo_position_;
+ }
- if (this->owningBlock() == n->owningBlock()) {
- return this->topo_position_ > n->topo_position_;
+ JIT_ASSERT(this == n);
+ return false;
}
// These nodes don't share a common block. Traverse the blockchains upward
JIT_ASSERT(rhs->owningBlock());
if (lhs->owningBlock() == rhs->owningBlock()) {
- return lhs->isAfter(rhs);
+ return lhs->isBeforeOrAfter(rhs, moveSide);
}
rhs = rhs->owningBlock()->owningNode();
}
}
// should never reach here, since both nodes are ultimately in the same graph
JIT_ASSERT(false);
+
+}
+
+bool Node::isBefore(const Node * n) const {
+ return isBeforeOrAfter(n, MoveSide::BEFORE);
+}
+
+bool Node::isAfter(const Node * n) const {
+ return isBeforeOrAfter(n, MoveSide::AFTER);
}
Node* Node::insertBefore(Node * n) {
enum class MoveSide { BEFORE, AFTER };
bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun);
void move(Node* movePoint, MoveSide moveSide);
+ bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
std::pair<Value*, const Argument&> findInput(Symbol name);
void findSchema() const;
const auto & block = *block_;
return block.nodes();
}
+ Node * param_node() {
+ return block_->param_node();
+ }
+ const Node * param_node() const {
+ return block_->param_node();
+ }
Node * return_node() {
return block_->return_node();
}
}
} // namespace
+AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(graph) {
+ analyze(graph_);
+
+ // Build helper indices
+  // NOTE: these indices assume that AliasDb is immutable once constructed.
+ // - Alias set -> value mapping
+ for (const auto& pr : valueToAlias_) {
+ const auto value = pr.first;
+ const auto& aliasInfo = pr.second;
+ // We don't support composite types yet
+ JIT_ASSERT(aliasInfo.containedTypes().size() == 0);
+ for (const auto aliasSet : aliasInfo.sets()) {
+ aliasToValue_[aliasSet].insert(value);
+ }
+ }
+ // - Set of all nodes with a wildcard
+ buildWildcardIndex(graph->block());
+}
+
+void AliasDb::buildWildcardIndex(const Block* b) {
+ for (const auto node : b->nodes()) {
+ for (const auto block : node->blocks()) {
+ buildWildcardIndex(block);
+ }
+
+ if (hasWildcardImpl(node)) {
+ wildcardNodes_.insert(node);
+ }
+ }
+}
+
bool AliasDb::hasWildcard(const Node* n) const {
+ return wildcardNodes_.count(n) != 0;
+}
+
+// Does `n` use or write to any wildcard aliases?
+bool AliasDb::hasWildcardImpl(const Node* n) const {
for (const auto input : n->inputs()) {
if (valueToAlias_.count(input) != 0 &&
valueToAlias_.at(input).isWildcard()) {
return false;
}
+bool AliasDb::writesToInputAlias(Node* n) const {
+ std::vector<const Value*> writes;
+ for (const auto input : n->inputs()) {
+ if (writesTo(n, input)) {
+ writes.push_back(input);
+ }
+ }
+ for (const auto output : n->outputs()) {
+ if (writesTo(n, output)) {
+ writes.push_back(output);
+ }
+ }
+
+ // For all writes, check if the written value may alias a graph input
+ return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
+ const auto& aliasInfo = valueToAlias_.at(v);
+ const auto& aliasSets = aliasInfo.sets();
+
+ // Check every distinct alias set this value belongs to
+ return std::any_of(
+ aliasSets.cbegin(), aliasSets.cend(), [&](const Symbol aliasSet) {
+ return graphInputAliases_.count(aliasSet) != 0;
+ });
+ });
+}
+
std::unordered_set<Node*> AliasDb::getWritersForNode(const Node* n) const {
// Get all alias sets of this node
// ... check the inputs
}
std::cout << "\n";
}
+
+ std::cout << "\n===3. WILDCARD INDEX===\n";
+ for (const auto node : wildcardNodes_) {
+ node->dump();
+ }
}
void AliasDb::analyze(std::shared_ptr<Graph> graph) {
// Assign aliases to the graph's inputs, assuming that all inputs of a given
// type may alias to each other.
- const auto tensorAlias = getFreshAlias();
+ const auto tensorAlias = getFreshAlias(/*isGraphInput=*/true);
// Create a separate alias set for each list type
std::map<TypeKind, Symbol> listTypeAliases;
// Create a separate alias set for each tuple type
containedType = DynamicType::get();
}
if (listTypeAliases.count(containedType->kind()) == 0) {
- listTypeAliases[containedType->kind()] = getFreshAlias();
+ listTypeAliases[containedType->kind()] =
+ getFreshAlias(/*isGraphInput=*/true);
}
addAlias(input, listTypeAliases.at(containedType->kind()));
} else if (inputType->kind() == TypeKind::TupleType) {
auto tupleType = inputType->cast<TupleType>();
if (tupleTypeAliases.count(tupleType) == 0) {
- tupleTypeAliases[tupleType] = getFreshAlias();
+ tupleTypeAliases[tupleType] = getFreshAlias(/*isGraphInput=*/true);
}
addAlias(input, tupleTypeAliases.at(tupleType));
} else {
analyze(subgraphBlock);
- mapAliases(node->outputs(), subgraphBlock->outputs());
+  // TODO(suo): the subgraph outputs and node outputs are NOT NECESSARILY the
+  // same length. Autodifferentiation may capture additional outputs in the
+  // subgraph block.
+ JIT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
+ for (size_t i = 0; i < node->outputs().size(); i++) {
+ addAlias(node->outputs()[i], subgraphBlock->outputs()[i]);
+ }
}
// For nodes that generate a fresh value from nothing
}
}
-Symbol AliasDb::getFreshAlias() const {
+Symbol AliasDb::getFreshAlias(bool isGraphInput) {
auto num = std::stoll(latestSymbol_.toUnqualString());
latestSymbol_ = Symbol::fromQualString("alias::" + std::to_string(++num));
+ if (isGraphInput) {
+ graphInputAliases_.insert(latestSymbol_);
+ }
return latestSymbol_;
}
*/
class AliasDb {
public:
- AliasDb(std::shared_ptr<Graph> graph) : graph_(graph) {
- analyze(graph_);
- }
+ explicit AliasDb(std::shared_ptr<Graph> graph);
- // Does `n` contain any wildcard aliases?
+ // Does `n` use or write to any wildcard aliases?
bool hasWildcard(const Node* n) const;
+ const std::unordered_set<const Node*>& getWildcardNodes() const {
+ return wildcardNodes_;
+ }
+
// Does `n` write to any alias sets?
bool hasWrites(Node* n) const;
+ // Does `n` write to a value that may alias one of the graph inputs?
+ bool writesToInputAlias(Node* n) const;
+
// Get all nodes that write to any alias set inputed/outputed by `n`
std::unordered_set<Node*> getWritersForNode(const Node* n) const;
void analyzeChunk(Node* node);
void analyzeBroadcastingChunk(Node* node);
- Symbol getFreshAlias() const;
+ Symbol getFreshAlias(bool isGraphInput = false);
void addAlias(const Value* value, AliasInfo alias);
void addAlias(const Value* value, Symbol alias);
void addAlias(const Value* value, const Value* from);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
void giveFreshAlias(const Value* value);
+ bool hasUsesAfter(Symbol alias, const Node* n) const;
+ void buildWildcardIndex(const Block* b);
+ bool hasWildcardImpl(const Node* n) const;
bool writesTo(Node* n, const Value* v) const;
std::shared_ptr<Graph> graph_;
- mutable Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
+ Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
std::unordered_map<const Value*, AliasInfo> valueToAlias_;
+ std::unordered_map<Symbol, std::unordered_set<const Value*>> aliasToValue_;
std::unordered_map<Symbol, std::unordered_set<Node*>> aliasToWrites_;
+ std::unordered_set<const Node*> wildcardNodes_;
+ std::unordered_set<Symbol> graphInputAliases_;
};
inline TORCH_API AliasDb AliasAnalysis(std::shared_ptr<Graph> graph) {
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
- aliasDb.hasWriters(node)) {
+ aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
// Do NOT have enough information to do CSE on these nodes.
continue;
}
});
bool supported_node = !n->kind().is_onnx() &&
skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
- !aliasDb.hasWriters(n);
+ !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
auto run_blocks = [&]() {
if (recurse) {
for (Block * block : n->blocks()) {
-#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "dead_code_elimination.h"
+
+#include "torch/csrc/jit/passes/alias_analysis.h"
#include <unordered_map>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
-using bool_memo_type = std::unordered_map<Node*, bool>;
+class DeadCodeEliminator {
+ public:
+ explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
+ : aliasDb_(AliasAnalysis(graph)) {}
+ DeadCodeEliminator(){};
-bool isMutable(Node* node) {
- if(!node->kind().is_aten())
- return false;
- // onnx export calls EliminateDeadCode but sometimes passes invalid
- // aten operators. So we call maybeSchema so we handle the cases when
- // there is no valid schema for a node
- auto schema = node->maybeSchema();
- return schema && schema->is_mutable();
-}
+ // The algorithm is an inverse mark-and-sweep. Starting from the return node,
+ // we mark "live" nodes that are necessary for the output. Nodes that have
+ // side effects are also marked.
+ void run(Block* block, bool recurse) {
+ // Find the last wildcard in the block. We cannot eliminate any mutable ops
+ // that precede the last wildcard.
+ setLastWildcard();
-bool hasSideEffects(Node * node, bool_memo_type& memo) {
- // FIXME: PythonOp should be treated as having side effects as well!
- // Unfortunately ONNX depends on it getting removed in this pass, so it's not
- // a simple change.
- auto it = memo.find(node);
- if (it != memo.end())
- return it->second;
- bool has_side_effects =
- node->kind() == prim::Print ||
- node->kind() == prim::RaiseException ||
- std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) {
- return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) {
- return hasSideEffects(n, memo);
- });
- }) || isMutable(node);
-
- memo.emplace(node, has_side_effects);
- return has_side_effects;
-}
+ // Initialize by adding the return node to work list
+ markAndEnqueue(block->return_node());
-void removeDeadIfOutputs(Node* node) {
- if (node->kind() != prim::If)
- return;
- for(size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
- size_t i = i_1 - 1;
- if (!node->outputs().at(i)->hasUses()) {
- node->eraseOutput(i);
- for (Block* b : node->blocks()) {
- b->eraseOutput(i);
+ mark(block);
+ sweep(block, recurse);
+ }
+
+ private:
+ void setLastWildcard() {
+ if (!aliasDb_) {
+ return;
+ }
+
+ const auto& wildcards = aliasDb_->getWildcardNodes();
+ if (wildcards.empty()) {
+ return;
+ }
+
+ lastWildcard_ = *wildcards.begin();
+ for (const auto wildcard : wildcards) {
+ if (wildcard->isAfter(*lastWildcard_)) {
+ lastWildcard_ = wildcard;
}
}
}
-}
-void removeDeadLoopOutputs(Node* node) {
- if (node->kind() != prim::Loop)
- return;
- auto loop_body = node->blocks().at(0);
- auto loop_input_offset = 2; // offset of loop carried deps in input list
- auto loop_body_offset = 1; // offset to the loop carried dependencies in block inputs/outputs
-
- for(size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
- size_t i = i_1 - 1;
- if (!node->outputs().at(i)->hasUses() &&
- !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
- node->eraseOutput(i);
- node->removeInput(loop_input_offset + i);
- loop_body->eraseInput(loop_body_offset + i);
- loop_body->eraseOutput(loop_body_offset + i);
+ void mark(Block* block) {
+ // Mark all nodes with side effects.
+ for (auto node : block->nodes()) {
+ if (hasSideEffects(node)) {
+ markAndEnqueue(node);
+ }
+ }
+
+ while (!workQueue_.empty()) {
+ auto node = workQueue_.front();
+ workQueue_.pop_front();
+
+ for (auto subBlock : node->blocks()) {
+ mark(subBlock);
+ }
+
+ // Mark all nodes in this node's blockchain (since owning nodes are
+ // considered live if they contain a live node)
+ if (node->owningBlock() != block) {
+ auto curNode = node;
+ while (curNode) {
+ if (!curNode->owningBlock()) {
+ break;
+ }
+
+ markAndEnqueue(curNode);
+ curNode = curNode->owningBlock()->owningNode();
+ }
+ }
+
+ // Find preceding writers for node, add to work list
+ if (aliasDb_) {
+ for (auto writer : aliasDb_->getWritersForNode(node)) {
+ if (writer->isBefore(node)) {
+ markAndEnqueue(writer);
+ }
+ }
+ }
+
+ // Find producers for all inputs, add to work list
+ for (auto input : node->inputs()) {
+ markAndEnqueue(input->node());
+ }
}
}
-}
-void EliminateDeadCode(Block *block, bool recurse, bool_memo_type& memo) {
- auto nodes = block->nodes().reverse();
- for (auto it = nodes.begin(); it != nodes.end(); it++) {
- auto node = *it;
- // note these occur before the recursion because we want to uncover
- // dead code in the blocks used to calculate the output
- removeDeadIfOutputs(node);
- removeDeadLoopOutputs(node);
- if (recurse) {
- for (Block * block : node->blocks())
- EliminateDeadCode(block, true, memo);
+ // Delete all unmarked nodes.
+ void sweep(Block* block, bool recurse) {
+ auto nodes = block->nodes().reverse();
+ for (auto it = nodes.begin(); it != nodes.end(); it++) {
+ auto node = *it;
+ // note these occur before the recursion because we want to uncover
+ // dead code in the blocks used to calculate the output
+ removeDeadIfOutputs(node);
+ removeDeadLoopOutputs(node);
+ if (recurse) {
+ for (Block* block : node->blocks()) {
+ sweep(block, true);
+ }
+ }
+ // TODO(suo): We shouldn't really have to check whether a node has uses,
+ // since the mark algorithm should do that. But currently, the marking
+ // doesn't reach loop counters in certain cases (see TestScript.test_pass)
+ if (!marked_.count(node) && !node->hasUses()) {
+ it.destroyCurrent();
+ }
}
- if (!node->hasUses() && !hasSideEffects(node, memo))
- it.destroyCurrent();
}
-}
+
+ void markAndEnqueue(Node* n) {
+ if (!marked_.count(n)) {
+ marked_.insert(n);
+ workQueue_.push_back(n);
+ }
+ }
+
+ bool hasUntrackedMutation(Node* node) {
+ if (!aliasDb_) {
+ // If we don't have alias information, all mutable ops have unknown
+ // effects and can't be considered for elimination.
+ if (!node->kind().is_aten()) {
+ return false;
+ }
+ // onnx export calls EliminateDeadCode but sometimes passes invalid
+ // aten operators. So we call maybeSchema so we handle the cases when
+ // there is no valid schema for a node
+ auto schema = node->maybeSchema();
+ return schema && schema->is_mutable();
+ } else {
+ // Otherwise, there are two kinds of nodes with untracked effects:
+ // 1. Nodes that write to a value that may alias the graph inputs (since
+ // the inputs can be used outside the graph).
+ // 2. Anything that could clobber a wildcard value.
+ bool touchesWildcard = false;
+ if (lastWildcard_) {
+ touchesWildcard = aliasDb_->hasWrites(node) &&
+ (node->isBefore(*lastWildcard_) || node == *lastWildcard_);
+ }
+ return aliasDb_->writesToInputAlias(node) || touchesWildcard;
+ }
+ }
+
+ bool hasSideEffects(Node* node) {
+ // FIXME: PythonOp should be treated as having side effects as well!
+ // Unfortunately ONNX depends on it getting removed in this pass, so
+ // it's not a simple change.
+ auto it = memo_.find(node);
+ if (it != memo_.end())
+ return it->second;
+ bool has_side_effects = node->kind() == prim::Print ||
+ node->kind() == prim::RaiseException ||
+ std::any_of(node->blocks().begin(),
+ node->blocks().end(),
+ [&](Block* b) {
+ return std::any_of(
+ b->nodes().begin(), b->nodes().end(), [&](Node* n) {
+ return hasSideEffects(n);
+ });
+ }) ||
+ hasUntrackedMutation(node);
+
+ memo_.emplace(node, has_side_effects);
+ return has_side_effects;
+ }
+
+ void removeDeadIfOutputs(Node* node) {
+ if (node->kind() != prim::If)
+ return;
+
+ for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
+ size_t i = i_1 - 1;
+ if (!node->outputs().at(i)->hasUses()) {
+ node->eraseOutput(i);
+ for (Block* b : node->blocks()) {
+ b->eraseOutput(i);
+ }
+ }
+ }
+ }
+
+ void removeDeadLoopOutputs(Node* node) {
+ if (node->kind() != prim::Loop)
+ return;
+ auto loop_body = node->blocks().at(0);
+ auto loop_input_offset = 2; // offset of loop carried deps in input list
+ auto loop_body_offset =
+ 1; // offset to the loop carried dependencies in block inputs/outputs
+
+ for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
+ size_t i = i_1 - 1;
+ if (!node->outputs().at(i)->hasUses() &&
+ !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
+ node->eraseOutput(i);
+ node->removeInput(loop_input_offset + i);
+ loop_body->eraseInput(loop_body_offset + i);
+ loop_body->eraseOutput(loop_body_offset + i);
+ }
+ }
+ }
+
+ c10::optional<AliasDb> aliasDb_;
+ std::unordered_map<Node*, bool> memo_;
+ std::unordered_set<Node*> marked_;
+ std::list<Node*> workQueue_;
+ c10::optional<const Node*> lastWildcard_;
+
+};
void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
- bool_memo_type side_effect_memo;
- EliminateDeadCode(graph->block(), true, side_effect_memo);
+ DeadCodeEliminator(graph).run(graph->block(), true);
}
-void EliminateDeadCode(Block *block, bool recurse) {
- bool_memo_type side_effect_memo;
- EliminateDeadCode(block, recurse, side_effect_memo);
+void EliminateDeadCode(Block* block, bool recurse) {
+ DeadCodeEliminator().run(block, recurse);
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
namespace torch { namespace jit {
+// If given a top-level graph, DCE will perform alias analysis that allows
+// for "smarter" dead code elimination (we will eliminate mutable ops if we can
+// prove the mutated values are not used). Otherwise, we will not allow DCE to
+// eliminate mutable ops.
+//
+// So, prefer to use the graph version if you can.
TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace {
-bool canRunWithAutograd(Node *node) {
+bool canRunWithAutograd(Node* node) {
return node->kind() != prim::FusionGroup;
}
-void inlineNode(Node *node) {
- WithInsertPoint insert_guard { node };
- Graph * graph = node->owningGraph();
- auto subgraph = node->g(attr::Subgraph);
- std::unordered_map<Value*, Value*> input_map;
+void InlineAutodiffSubgraphs(Block* block, size_t threshold);
+
+graph_node_list::iterator scanNode(Node* node, size_t threshold) {
+ auto next_node = ++node->iterator();
+
+ for (Block* block : node->blocks()) {
+ InlineAutodiffSubgraphs(block, threshold);
+ }
- size_t num_inputs = node->inputs().size();
- JIT_ASSERT(num_inputs == subgraph->inputs().size());
- for (size_t i = 0; i < num_inputs; ++i) {
- input_map[subgraph->inputs()[i]] = node->inputs()[i];
+ if (node->kind() != prim::DifferentiableGraph) {
+ return next_node;
}
- for (Node * subnode : subgraph->nodes()) {
- Node * new_node = graph->insertNode(graph->createClone(subnode, [&](Value * v) { return input_map.at(v); }));
- for (size_t i = 0; i < subnode->outputs().size(); ++i) {
- input_map[subnode->output(i)] = new_node->output(i);
- }
+ auto subgraph = node->g(attr::Subgraph);
+ int64_t subgraph_size =
+ std::distance(subgraph->nodes().begin(), subgraph->nodes().end());
+ if (subgraph_size >= static_cast<int64_t>(threshold)) {
+ return next_node;
}
- size_t num_outputs = node->outputs().size();
- JIT_ASSERT(num_outputs <= subgraph->outputs().size() &&
- num_outputs == static_cast<size_t>(node->i(attr::f_real_outputs)));
- for (size_t i = 0; i < num_outputs; ++i) {
- node->output(i)->replaceAllUsesWith(input_map.at(subgraph->outputs()[i]));
+ if (!std::all_of(
+ subgraph->nodes().begin(),
+ subgraph->nodes().end(),
+ canRunWithAutograd)) {
+ return next_node;
}
+
+ SubgraphUtils::unmergeSubgraph(node);
+ return next_node;
}
-void InlineAutodiffSubgraphs(Block *block, size_t threshold) {
- for (Node * node : block->nodes()) {
- for (Block * block : node->blocks()) {
- InlineAutodiffSubgraphs(block, threshold);
- }
- if (node->kind() != prim::DifferentiableGraph) continue;
- auto subgraph = node->g(attr::Subgraph);
- int64_t subgraph_size = std::distance(subgraph->nodes().begin(), subgraph->nodes().end());
- if (subgraph_size >= static_cast<int64_t>(threshold)) continue;
- if (!std::all_of(subgraph->nodes().begin(), subgraph->nodes().end(), canRunWithAutograd)) continue;
- inlineNode(node);
+void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ it = scanNode(*it, threshold);
}
}
EliminateDeadCode(graph);
}
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
void LowerAllTuples(std::shared_ptr<Graph>& graph) {
LowerAllTuples(graph->block());
- EliminateDeadCode(graph);
+ EliminateDeadCode(graph->block());
EnsureNoTuples(graph->block());
}
std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
return writer->isBefore(node);
});
- if (hasWritersBefore) {
+ if (hasWritersBefore || aliasDb_.hasWildcard(node)) {
// If something could have written to a value used by this node, we can't
// guarantee the result is the same when running it in isolation.
dependsOnMutationMemo_[node] = true;
return {};
}};
-
static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
for (Value* input : node->inputs()) {
if (auto type = input->type()->cast<TensorType>()) {
// Replace uses of group outputs and destroy the group
const auto subgraphOutputs = subgraph->outputs();
- for (size_t i = 0; i < subgraphOutputs.size(); ++i) {
+ JIT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
+ for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
}