fix auto grad summing for IfOp where intermediate output needs renaming (#14772)
authorYiming Wu <wyiming@fb.com>
Sun, 9 Dec 2018 16:23:36 +0000 (08:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 9 Dec 2018 16:26:46 +0000 (08:26 -0800)
Summary:
fix auto grad summing for IfOp where intermediate output needs renaming.

Bug before this diff:
- we only renamed the output of the IfOp without renaming the matching outputs of the ops inside its then_net / else_net subnets
- this resulted in a "blob not found" error when the gradient net ran

The unit test provides an example; this diff fixes the renaming for IfOp, as sketched below.
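
Roughly, the failure mode looks like the following minimal sketch (it uses the
same caffe2 Python API as the new unit test; the blob and net names are made up
for illustration, not taken from a real model):

    from caffe2.python import core
    from caffe2.python.model_helper import ModelHelper

    # both branches of the If write the same gradient blob, "input_grad"
    then_model = ModelHelper(name="then_branch")
    then_model.net.Copy("then_grad", "input_grad")
    else_model = ModelHelper(name="else_branch")
    else_model.net.Copy("else_grad", "input_grad")

    grad_if = core.CreateOperator(
        "If",
        ["cond", "then_grad", "else_grad"],
        ["input_grad"],
        then_net=then_model.net.Proto(),
        else_net=else_model.net.Proto(),
    )

    # Before this diff, auto grad summing renamed only the outer output:
    grad_if.output[0] = "_input_grad_autosplit_0"
    # ...but then_net / else_net still write "input_grad", so the renamed
    # blob is never actually produced, leading to the "blob not found" error.

    # After this diff, _DisambiguateGradOpOutput calls
    # disambiguate_grad_if_op_output, which also rewrites the matching
    # outputs inside then_net and else_net.
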
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14772

Differential Revision: D13327090

Pulled By: harouwu

fbshipit-source-id: ec40ee88526ace3619c54551e223dd71158a02f8

caffe2/python/control_ops_grad.py
caffe2/python/control_ops_grad_test.py [new file with mode: 0644]
caffe2/python/core.py

index e004507..fa6753a 100644 (file)
@@ -683,3 +683,19 @@ def _prepare_gradient_if_op(
     del gradient_if_def.control_input[:]
     gradient_if_def.is_gradient_op = True
     return gradient_if_def
+
+
+def disambiguate_grad_if_op_output(grad_op, idx, new_grad_output):
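+    """Rename output `idx` of a gradient If op both on the op itself and on
+    the matching outputs of the ops inside its then_net / else_net subnets,
+    so that the renamed gradient blob is actually produced."""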
+    then_net = _get_net_argument(grad_op, "then_net")
+    old_grad_out_match = grad_op.output[idx]
+    for op in then_net.op:
+        for i, out in enumerate(op.output):
+            if out == old_grad_out_match:
+                op.output[i] = new_grad_output
+    else_net = _get_net_argument(grad_op, "else_net")
+    if else_net:
+        for op in else_net.op:
+            for i, out in enumerate(op.output):
+                if out == old_grad_out_match:
+                    op.output[i] = new_grad_output
+    grad_op.output[idx] = new_grad_output
diff --git a/caffe2/python/control_ops_grad_test.py b/caffe2/python/control_ops_grad_test.py
new file mode 100644 (file)
index 0000000..2d11328
--- /dev/null
@@ -0,0 +1,40 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core, test_util, workspace
+from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
+from caffe2.python.model_helper import ModelHelper
+import numpy as np
+
+
+class TestControl(test_util.TestCase):
+    def test_disambiguate_grad_if_op_output(self):
+        workspace.FeedBlob("cond", np.array(True))
+        workspace.FeedBlob("then_grad", np.array(1))
+        workspace.FeedBlob("else_grad", np.array(2))
+
+        then_model = ModelHelper(name="then_test_model")
+        then_model.net.Copy("then_grad", "input_grad")
+
+        else_model = ModelHelper(name="else_test_model")
+        else_model.net.Copy("else_grad", "else_temp_grad")
+        else_model.net.Copy("else_temp", "input_grad")
+
+        # For BuildGradientGenerators, in the forward pass we need else_temp_grad
+        # as one of the outputs, which later on results in a grad op like this:
+        grad_op = core.CreateOperator(
+            "If",
+            ["cond", "then_grad", "else_grad"],
+            ["input_grad", "else_temp_grad"],
+            then_net=then_model.net.Proto(),
+            else_net=else_model.net.Proto(),
+        )
+
+        # in certain cases, another part of the net also generates input_grad,
+        # and _DisambiguateGradOpOutput in core.py is called to rename it:
+        new_grad_output = "input_grad" + "_autosplit_" + "0"
+        disambiguate_grad_if_op_output(grad_op, 0, new_grad_output)
+        self.assertEqual(grad_op.output[0], new_grad_output)
+        self.assertEqual(grad_op.arg[1].n.op[1].output[0], new_grad_output)
index 6cab923..9011fed 100644 (file)
@@ -14,7 +14,7 @@ from six import binary_type, string_types, text_type
 from caffe2.proto import caffe2_pb2
 from caffe2.python import scope, utils, workspace
 from caffe2.python.control_ops_grad import \
-    gen_do_gradient, gen_if_gradient, gen_while_gradient
+    gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output
 
 import caffe2.python._import_c_extension as C
 
@@ -725,8 +725,12 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
                 break
 
     def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
-        grad_op.output[idx] = (
+        new_grad_output = (
             '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
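+        # For If ops, renaming only grad_op.output[idx] would leave the
+        # then_net / else_net subnets writing to the old blob name, so the
+        # subnet outputs must be rewritten as well.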
+        if grad_op.type == "If":
+            disambiguate_grad_if_op_output(grad_op, idx, new_grad_output)
+        else:
+            grad_op.output[idx] = new_grad_output
         return grad_op.output[idx], cnt + 1
 
     def _CheckSumOpsConflict(self, out_base_name, g):