From ca962f0f95185e76b1f43a4423d28986ee0da191 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Sun, 24 Mar 2019 14:28:22 -0700
Subject: [PATCH] Fix For Requires Grad Infinite Loop (#18361)

Summary:
Previously, requires_grad analysis would keep re-running on a loop body for as
long as the body's outputs and inputs disagreed. This adds a check so that we
stop iterating once the results have not changed since the last run.

Fix for https://github.com/pytorch/pytorch/issues/18320

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18361

Differential Revision: D14584332

Pulled By: eellison

fbshipit-source-id: 696b225f80a2036318540946428b525985a9e735
---
 test/test_jit.py                                 | 26 ++++++++++++++++++++++++
 torch/csrc/jit/passes/requires_grad_analysis.cpp | 22 ++++++++++++--------
 torch/csrc/jit/python_ir.cpp                     |  1 +
 3 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 17965a2..d5b644e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4594,6 +4594,32 @@ a")
 
         test_resize_as()
 
+    def test_requires_grad_loop(self):
+        @torch.jit.script
+        def test(x, y, z):
+            # type: (Tensor, Tensor, int) -> Tensor
+            for _ in range(z):
+                x = y
+            return x
+
+        # x requires grad, y does not
+        # testing that requires_grad analysis exits correctly, with the loop
+        # body's input (x) requiring grad, the loop body's output not requiring
+        # grad, and the output of the loop node conservatively requiring grad
+
+        inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
+        test(*inps)
+
+        graph = test.graph_for(*inps)
+        loop = graph.findNode("prim::Loop")
+        loop_body = next(loop.blocks())
+        loop_inputs = list(loop_body.inputs())
+        loop_outputs = list(loop_body.outputs())
+
+        self.assertTrue(loop_inputs[1].requires_grad())
+        self.assertFalse(loop_outputs[1].requires_grad())
+        self.assertTrue(loop.output().requires_grad())
+
     def test_view_shape_prop(self):
         cu = torch.jit.CompilationUnit('''
         def test_view_shape_prop(a):
diff --git a/torch/csrc/jit/passes/requires_grad_analysis.cpp b/torch/csrc/jit/passes/requires_grad_analysis.cpp
index 86d1745..6e0e427 100644
--- a/torch/csrc/jit/passes/requires_grad_analysis.cpp
+++ b/torch/csrc/jit/passes/requires_grad_analysis.cpp
@@ -1,7 +1,7 @@
+#include
 #include
 #include
 #include
-#include
 
 #include
 
@@ -99,15 +99,23 @@ void PropagateRequiresGrad(Node* node) {
         fmap(node->inputs().slice(2), getRequiresGrad);
     std::vector<bool> body_outputs_require(node->outputs().size(), false);
 
-    while (body_inputs_require != body_outputs_require) {
-      body_inputs_require =
+    std::vector<bool> new_body_inputs_require = body_inputs_require;
+    std::vector<bool> new_body_outputs_require = body_outputs_require;
+
+    // continue iterating until the results have converged
+    do {
+      body_inputs_require = new_body_inputs_require;
+      body_outputs_require = new_body_outputs_require;
+
+      new_body_inputs_require =
           bitwiseOr(body_inputs_require, body_outputs_require);
       setRequiresGrad(
-          body->param_node()->outputs().slice(1), body_inputs_require);
+          body->param_node()->outputs().slice(1), new_body_inputs_require);
       PropagateRequiresGrad(body);
-      body_outputs_require =
+      new_body_outputs_require =
           fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
-    }
+    } while (new_body_inputs_require != body_inputs_require &&
+             new_body_outputs_require != body_outputs_require);
 
     setRequiresGrad(node, body_outputs_require);
   } else {
@@ -120,12 +128,10 @@ void PropagateRequiresGrad(Block* block) {
     PropagateRequiresGrad(node);
   }
 }
-
 } // anonymous namespace
 
 void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
   PropagateRequiresGrad(graph->block());
 }
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index cf885bb..8131d6b 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -385,6 +385,7 @@ void initPythonIRBindings(PyObject* module_) {
       })
       .VS(copyMetadata)
      .VS(isTensor)
+      .VS(requires_grad)
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
 #undef VS
-- 
2.7.4
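
A note on the approach: the check described in the summary is a standard fixed-point
iteration. The requires_grad flags of the loop-carried values can only flip from False
to True, the previous iteration's output flags are OR-ed back into the input flags, and
iteration stops once nothing has changed since the last run. The Python below is only a
minimal standalone sketch of that idea; propagate_body is a hypothetical stand-in for
running PropagateRequiresGrad over the loop block, and the sketch does not mirror the
C++ line for line.

def propagate_loop_requires_grad(initial_inputs_require, propagate_body):
    # Flags for the loop-carried inputs and the body outputs; the analysis is
    # assumed monotone, so a flag only ever flips from False to True.
    body_inputs = list(initial_inputs_require)
    body_outputs = [False] * len(body_inputs)

    new_inputs, new_outputs = body_inputs, body_outputs
    while True:
        body_inputs, body_outputs = new_inputs, new_outputs
        # Feed the previous outputs back into the loop-carried inputs.
        new_inputs = [i or o for i, o in zip(body_inputs, body_outputs)]
        new_outputs = propagate_body(new_inputs)
        # Stop once nothing changed since the last run (full convergence).
        if new_inputs == body_inputs and new_outputs == body_outputs:
            return body_outputs

# Shape of the bug in issue 18320: x requires grad, the body always returns a
# value that does not, so inputs and outputs never agree, but the flags stop
# changing after one pass and the iteration terminates.
print(propagate_loop_requires_grad([True], lambda inputs: [False]))  # [False]

Run against the shape of the regression test above (a carried input that requires grad
and a body that always returns a value that does not), the flags stop changing after a
single pass, so the iteration terminates instead of looping forever.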