From 33018e4e09b16075440ea72a6929b15c7ae670f5 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 19 Dec 2018 10:45:32 -0800 Subject: [PATCH] centralize side effects ops as node method (#15188) Summary: A number of different passes rely on whether a node has side effects. This centralizes the list of side effectful ops in one place. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15188 Differential Revision: D13508438 Pulled By: eellison fbshipit-source-id: 2143e782b787731ce007b6dcd50cbde30e1b8dd0 --- torch/csrc/jit/ir.cpp | 11 +++++++++++ torch/csrc/jit/ir.h | 1 + torch/csrc/jit/passes/common_subexpression_elimination.cpp | 3 +-- torch/csrc/jit/passes/constant_propagation.cpp | 6 +----- torch/csrc/jit/passes/dead_code_elimination.cpp | 4 +--- torch/csrc/jit/passes/shape_analysis.cpp | 9 +++++---- 6 files changed, 20 insertions(+), 14 deletions(-) diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 650d5ef..df038e3 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -686,6 +686,17 @@ bool Node::isNondeterministic() const { return true; } +bool Node::hasSideEffects() const { + switch (kind_) { + case prim::PythonOp: + case prim::Print: + case prim::RaiseException: + case aten::warn: + return true; + } + return false; +} + // Assign this node a topological position, to facilitate fast isBefore() and // isAfter() queries. Must be called right after a node is inserted into the // node list. diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 2a6c9cf..71a5361 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -353,6 +353,7 @@ public: } TORCH_API bool isNondeterministic() const; + TORCH_API bool hasSideEffects () const; // Graphs diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index b96cfaa..cac8f6b 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -23,8 +23,7 @@ void EliminateCommonSubexpression( std::unordered_set subexprs; for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) { auto node = *it; - if (node->kind() == prim::PythonOp || node->kind() == prim::Print || - node->kind() == aten::warn || node->isNondeterministic() || + if (node->hasSideEffects() || node->isNondeterministic() || aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) { // Do NOT have enough information to do CSE on these nodes. continue; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index a1a6c1a..2446759 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -16,10 +16,6 @@ namespace { std::unordered_set skip_list = { prim::If, prim::Loop, //TODO: handle Loop - prim::Print, - prim::RaiseException, - aten::warn, - prim::PythonOp, //may have side effects prim::Constant, prim::Undefined, prim::NoneGenerator, @@ -125,7 +121,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) { return v->node()->kind() == prim::Constant; }); bool supported_node = !n->kind().is_onnx() && - skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && + skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n); auto run_blocks = [&]() { if (recurse) { diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 0167d03..b7d606c 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -245,9 +245,7 @@ class DeadCodeEliminator { auto it = memo_.find(node); if (it != memo_.end()) return it->second; - bool has_side_effects = node->kind() == prim::Print || - node->kind() == aten::warn || node->kind() == prim::RaiseException || - node->kind() == prim::PythonOp || + bool has_side_effects = node->hasSideEffects() || std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 2d9677a..85465d5 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -434,10 +434,6 @@ class ShapePropagator { } return; } - case prim::PythonOp: - case prim::Print: - case prim::RaiseException: - case aten::warn: case prim::Undefined: { setUnshapedType(node); return; @@ -445,6 +441,11 @@ class ShapePropagator { default: break; // fall-through } + + if (node->hasSideEffects()) { + return; + } + if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") || node->kind() == prim::FusedConcat) { return PropagateCatShape(node); -- 2.7.4