bug fix for node with writers in create autodiff subgraph (#18491)
authorElias Ellison <eellison@fb.com>
Wed, 27 Mar 2019 23:02:10 +0000 (16:02 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Mar 2019 23:08:03 +0000 (16:08 -0700)
Summary:
Previously we were moving nodes with writers into differentiable subgraphs, without necessarily preserving whether or not they were written to. This can lead to bugs with CSE, which needs that context.

I'm not completely sure if there's anything else we can do to be more aggressive here - inline these nodes and not run CSE and just run constant pooling, or possibly something else, but I think we should land this correctness condition first and then possibly think further.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18491

Differential Revision: D14648562

Pulled By: eellison

fbshipit-source-id: bc1e444774ccdb708e22f0e06a477a221a231f9e

test/test_jit.py
torch/csrc/jit/passes/create_autodiff_subgraphs.cpp

index 94c3e89..694926b 100644 (file)
@@ -12942,6 +12942,28 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
         # the same group; they should each be a separate DiffGraph
         self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
 
+    def test_mutation_subgraph_inlining(self):
+        # cannot move a node which has writers into a differentiable subgraph,
+        # bc CSE might lose context that it has writers
+
+        def fn(x):
+            a = x.t()
+            a = a + 1
+            c = x.t()
+            c = c + 1
+            e = a + c
+            b = a.add_(x)
+            d = c.add_(x)
+            return e, b, d
+
+        fn_script = torch.jit.script(fn)
+        outs1 = fn_script(torch.tensor(0.5, requires_grad=True))
+        outs2 = fn(torch.tensor(0.5, requires_grad=True))
+        for i in range(len(outs1)):
+            self.assertEqual(outs1[i], outs2[i])
+        graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True))
+        FileCheck().check_not("DifferentiableGraph").run(graph)
+
 
 class TestCustomOperators(JitTestCase):
 
index 97e36fa..b7a55b8 100644 (file)
@@ -110,7 +110,7 @@ class SubgraphSlicer {
     return result;
   }
 
-  bool shouldConsiderForMerge(Node* node) {
+  bool shouldConsiderForMerge(Node* node, const AliasDb& aliasDb) {
     // if we're already in the process of merging
     if (node->kind() == prim::DifferentiableGraph) {
       return true;
@@ -118,13 +118,19 @@ class SubgraphSlicer {
     if (node->kind() == prim::Constant) {
       return false;
     }
+    // when a node which has writers is moved into a subgraph it may lose
+    // context and CSE could merge it with another node that has writers
+    // TODO: @eellison Fix problem more generally in CSE, land PR #18500
+    if (aliasDb.hasWriters(node)) {
+      return false;
+    }
     return isDifferentiable(node);
   }
 
   std::pair<graph_node_list::iterator, bool> scanNode(
       Node* consumer,
       AliasDb& aliasDb) {
-    if (shouldConsiderForMerge(consumer)) {
+    if (shouldConsiderForMerge(consumer, aliasDb)) {
       if (consumer->kind() != prim::DifferentiableGraph) {
         consumer = SubgraphUtils::createSingletonSubgraph(
             consumer, prim::DifferentiableGraph);
@@ -149,7 +155,7 @@ class SubgraphSlicer {
       Node* producer,
       AliasDb& aliasDb) {
     AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
-    bool canMerge = shouldConsiderForMerge(producer) &&
+    bool canMerge = shouldConsiderForMerge(producer, aliasDb) &&
         aliasDb.moveBeforeTopologicallyValid(producer, consumer);
 
     if (!canMerge) {
@@ -174,6 +180,5 @@ std::vector<Node*> CreateAutodiffSubgraphs(
   SubgraphSlicer(graph->block(), graph, threshold).run(diff_nodes);
   return diff_nodes;
 }
-
 } // namespace jit
 } // namespace torch