centralize side effects ops as node method (#15188)
authorElias Ellison <eellison@fb.com>
Wed, 19 Dec 2018 18:45:32 +0000 (10:45 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 18:52:54 +0000 (10:52 -0800)
Summary:
A number of different passes rely on whether a node has side effects. This centralizes the list of side effectful ops in one place.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15188

Differential Revision: D13508438

Pulled By: eellison

fbshipit-source-id: 2143e782b787731ce007b6dcd50cbde30e1b8dd0

torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/common_subexpression_elimination.cpp
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/shape_analysis.cpp

index 650d5ef..df038e3 100644 (file)
@@ -686,6 +686,17 @@ bool Node::isNondeterministic() const {
   return true;
 }
 
+bool Node::hasSideEffects() const {
+  switch (kind_) {
+    case prim::PythonOp:
+    case prim::Print:
+    case prim::RaiseException:
+    case aten::warn:
+      return true;
+  }
+  return false;
+}
+
 // Assign this node a topological position, to facilitate fast isBefore() and
 // isAfter() queries. Must be called right after a node is inserted into the
 // node list.
index 2a6c9cf..71a5361 100644 (file)
@@ -353,6 +353,7 @@ public:
   }
 
   TORCH_API bool isNondeterministic() const;
+  TORCH_API bool hasSideEffects () const;
 
   // Graphs
 
index b96cfaa..cac8f6b 100644 (file)
@@ -23,8 +23,7 @@ void EliminateCommonSubexpression(
   std::unordered_set<Node*, HashNode, EqualNode> subexprs;
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
     auto node = *it;
-    if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
-        node->kind() == aten::warn || node->isNondeterministic() ||
+    if (node->hasSideEffects() || node->isNondeterministic() ||
         aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
       // Do NOT have enough information to do CSE on these nodes.
       continue;
index a1a6c1a..2446759 100644 (file)
@@ -16,10 +16,6 @@ namespace {
 std::unordered_set<Symbol> skip_list = {
   prim::If,
   prim::Loop, //TODO: handle Loop
-  prim::Print,
-  prim::RaiseException,
-  aten::warn,
-  prim::PythonOp, //may have side effects
   prim::Constant,
   prim::Undefined,
   prim::NoneGenerator,
@@ -125,7 +121,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
         return v->node()->kind() == prim::Constant;
       });
   bool supported_node = !n->kind().is_onnx() &&
-      skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
+      skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() &&
       !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
   auto run_blocks = [&]() {
     if (recurse) {
index 0167d03..b7d606c 100644 (file)
@@ -245,9 +245,7 @@ class DeadCodeEliminator {
     auto it = memo_.find(node);
     if (it != memo_.end())
       return it->second;
-    bool has_side_effects = node->kind() == prim::Print ||
-        node->kind() == aten::warn || node->kind() == prim::RaiseException ||
-        node->kind() == prim::PythonOp ||
+    bool has_side_effects = node->hasSideEffects() ||
         std::any_of(node->blocks().begin(),
                     node->blocks().end(),
                     [&](Block* b) {
index 2d9677a..85465d5 100644 (file)
@@ -434,10 +434,6 @@ class ShapePropagator {
         }
         return;
       }
-      case prim::PythonOp:
-      case prim::Print:
-      case prim::RaiseException:
-      case aten::warn:
       case prim::Undefined: {
         setUnshapedType(node);
         return;
@@ -445,6 +441,11 @@ class ShapePropagator {
       default:
         break; // fall-through
     }
+
+    if (node->hasSideEffects()) {
+      return;
+    }
+
     if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")
        || node->kind() == prim::FusedConcat) {
       return PropagateCatShape(node);