Fix For Requires Grad Infinite Loop (#18361)
author Elias Ellison <eellison@fb.com>
Sun, 24 Mar 2019 21:28:22 +0000 (14:28 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 24 Mar 2019 21:34:50 +0000 (14:34 -0700)
Summary:
Previously, we would continue to run requires grad on a loop body when the outputs and inputs disagreed. This adds a check so that we don't continue running if the results haven't changed since the last run.

Fix for https://github.com/pytorch/pytorch/issues/18320
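The failure mode, mirroring the new test below: before this change, requires-grad analysis hung on a loop whose carried input requires grad but is overwritten every iteration by a value that does not, since the old inputs-equal-outputs exit condition could never be satisfied:

    import torch

    @torch.jit.script
    def f(x, y, z):
        # type: (Tensor, Tensor, int) -> Tensor
        for _ in range(z):
            x = y
        return x

    # x requires grad, y does not
    f(torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)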
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18361

Differential Revision: D14584332

Pulled By: eellison

fbshipit-source-id: 696b225f80a2036318540946428b525985a9e735

test/test_jit.py
torch/csrc/jit/passes/requires_grad_analysis.cpp
torch/csrc/jit/python_ir.cpp

diff --git a/test/test_jit.py b/test/test_jit.py
index 17965a2..d5b644e 100644
@@ -4594,6 +4594,32 @@ a")
 
         test_resize_as()
 
+    def test_requires_grad_loop(self):
+        @torch.jit.script
+        def test(x, y, z):
+            # type: (Tensor, Tensor, int) -> Tensor
+            for _ in range(z):
+                x = y
+            return x
+
+        # x requires grad, y does not
+        # tests that requires grad analysis terminates even though the loop
+        # body's input (x) requires grad while the body's output does not,
+        # and that the loop node's output conservatively requires grad
+
+        inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
+        test(*inps)
+
+        graph = test.graph_for(*inps)
+        loop = graph.findNode("prim::Loop")
+        loop_body = next(loop.blocks())
+        loop_inputs = list(loop_body.inputs())
+        loop_outputs = list(loop_body.outputs())
+
+        self.assertTrue(loop_inputs[1].requires_grad())
+        self.assertFalse(loop_outputs[1].requires_grad())
+        self.assertTrue(loop.output().requires_grad())
+
     def test_view_shape_prop(self):
         cu = torch.jit.CompilationUnit('''
         def test_view_shape_prop(a):
diff --git a/torch/csrc/jit/passes/requires_grad_analysis.cpp b/torch/csrc/jit/passes/requires_grad_analysis.cpp
index 86d1745..6e0e427 100644
@@ -1,7 +1,7 @@
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
-#include <ATen/core/jit_type.h>
 
 #include <vector>
 
@@ -99,15 +99,23 @@ void PropagateRequiresGrad(Node* node) {
         fmap(node->inputs().slice(2), getRequiresGrad);
     std::vector<bool> body_outputs_require(node->outputs().size(), false);
 
-    while (body_inputs_require != body_outputs_require) {
-      body_inputs_require =
+    std::vector<bool> new_body_inputs_require = body_inputs_require;
+    std::vector<bool> new_body_outputs_require = body_outputs_require;
+
+    // continue iterating until the results have converged
+    do {
+      body_inputs_require = new_body_inputs_require;
+      body_outputs_require = new_body_outputs_require;
+
+      new_body_inputs_require =
           bitwiseOr(body_inputs_require, body_outputs_require);
       setRequiresGrad(
-          body->param_node()->outputs().slice(1), body_inputs_require);
+          body->param_node()->outputs().slice(1), new_body_inputs_require);
       PropagateRequiresGrad(body);
-      body_outputs_require =
+      new_body_outputs_require =
           fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
-    }
+    } while (new_body_inputs_require != body_inputs_require ||
+             new_body_outputs_require != body_outputs_require);
 
     setRequiresGrad(node, body_outputs_require);
   } else {
@@ -120,12 +128,10 @@ void PropagateRequiresGrad(Block* block) {
     PropagateRequiresGrad(node);
   }
 }
-
 } // anonymous namespace
 
 void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
   PropagateRequiresGrad(graph->block());
 }
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index cf885bb..8131d6b 100644
@@ -385,6 +385,7 @@ void initPythonIRBindings(PyObject* module_) {
           })
       .VS(copyMetadata)
       .VS(isTensor)
+      .VS(requires_grad)
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
 #undef VS
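
The python_ir.cpp hunk exposes Value::requires_grad to the Python IR bindings, which the new test relies on. A rough usage sketch, assuming a simple scripted function inspected via graph_for:

    import torch

    @torch.jit.script
    def f(a):
        return a * 2

    a = torch.tensor(1.0, requires_grad=True)
    f(a)
    # graph inputs are Values; the new binding reads the requires_grad
    # bit recorded on each Value's type
    print([v.requires_grad() for v in f.graph_for(a).inputs()])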