From ad1ebf70827e367c8d0eae8852e11f2289301607 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 27 Mar 2019 16:02:10 -0700 Subject: [PATCH] bug fix for node with writers in create autodiff subgraph (#18491) Summary: Previously we were moving nodes with writers into differentiable subgraphs, without necessarily preserving whether or not they were written to. This can lead to bugs with CSE, which needs that context. I'm not completely sure if there's anything else we can do to be more aggresive here - inline these nodes and not run CSE and just run constant pooling, or possibly something else, but I think we should land this correctness condition first and then possibly think further. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18491 Differential Revision: D14648562 Pulled By: eellison fbshipit-source-id: bc1e444774ccdb708e22f0e06a477a221a231f9e --- test/test_jit.py | 22 ++++++++++++++++++++++ .../csrc/jit/passes/create_autodiff_subgraphs.cpp | 13 +++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 94c3e89..694926b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -12942,6 +12942,28 @@ class TestAutodiffSubgraphSlicing(JitTestCase): # the same group; they should each be a separate DiffGraph self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) + def test_mutation_subgraph_inlining(self): + # cannot move a node which has writers into a differentiable subgraph, + # bc CSE might lose context that it has writers + + def fn(x): + a = x.t() + a = a + 1 + c = x.t() + c = c + 1 + e = a + c + b = a.add_(x) + d = c.add_(x) + return e, b, d + + fn_script = torch.jit.script(fn) + outs1 = fn_script(torch.tensor(0.5, requires_grad=True)) + outs2 = fn(torch.tensor(0.5, requires_grad=True)) + for i in range(len(outs1)): + self.assertEqual(outs1[i], outs2[i]) + graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True)) + 
FileCheck().check_not("DifferentiableGraph").run(graph) + class TestCustomOperators(JitTestCase): diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 97e36fa..b7a55b8 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -110,7 +110,7 @@ class SubgraphSlicer { return result; } - bool shouldConsiderForMerge(Node* node) { + bool shouldConsiderForMerge(Node* node, const AliasDb& aliasDb) { // if we're already in the process of merging if (node->kind() == prim::DifferentiableGraph) { return true; @@ -118,13 +118,19 @@ class SubgraphSlicer { if (node->kind() == prim::Constant) { return false; } + // when a node which has writers is moved into a subgraph it may lose + // context and CSE could merge it with another node that has writers + // TODO: @eellison Fix problem more generally in CSE, land PR #18500 + if (aliasDb.hasWriters(node)) { + return false; + } return isDifferentiable(node); } std::pair<graph_node_list::iterator, bool> scanNode( Node* consumer, AliasDb& aliasDb) { - if (shouldConsiderForMerge(consumer)) { + if (shouldConsiderForMerge(consumer, aliasDb)) { if (consumer->kind() != prim::DifferentiableGraph) { consumer = SubgraphUtils::createSingletonSubgraph( consumer, prim::DifferentiableGraph); @@ -149,7 +155,7 @@ class SubgraphSlicer { Node* producer, AliasDb& aliasDb) { AT_ASSERT(consumer->kind() == prim::DifferentiableGraph); - bool canMerge = shouldConsiderForMerge(producer) && + bool canMerge = shouldConsiderForMerge(producer, aliasDb) && aliasDb.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { @@ -174,6 +180,5 @@ std::vector<Node*> CreateAutodiffSubgraphs( SubgraphSlicer(graph->block(), graph, threshold).run(diff_nodes); return diff_nodes; } - } // namespace jit } // namespace torch -- 2.7.4