add extra info for the auto gen sum ops
authorXianjie Chen <cxj@fb.com>
Wed, 27 Mar 2019 21:52:13 +0000 (14:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Mar 2019 21:56:32 +0000 (14:56 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17934

Reviewed By: iroot900

Differential Revision: D14418689

fbshipit-source-id: 9e11e461001467f0000ea7c355d5b0f0d738fa85

caffe2/python/core.py
caffe2/python/core_gradients_test.py

index 05e1226..2e9161a 100644 (file)
@@ -714,6 +714,8 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
 
         return input_name + '_grad'
 
+    IS_AUTO_GEN_SUM_OPS_TAG = "is_auto_gen_sum_ops"
+
     def _SetSumOpsDeviceOption(self, sum_ops, generators):
         # we already checked that device options are consistent so we can just
         # use the first one we find
@@ -724,7 +726,9 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
                 if grad_op.HasField('device_option'):
                     for op in sum_ops:
                         op.device_option.CopyFrom(grad_op.device_option)
-                        del op.device_option.extra_info[:]
+                        op.device_option.extra_info.extend([
+                            "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)
+                        ])
                 break
 
     def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
index 75c2689..b6e9817 100644 (file)
@@ -87,8 +87,12 @@ class TestGradientCalculation(test_util.TestCase):
     def assertOperatorListEqual(self, operatorDefList1, operatorDefList2):
         for op in operatorDefList1:
             op.debug_info = ""
+            if op.device_option:
+                del op.device_option.extra_info[:]
         for op in operatorDefList2:
             op.debug_info = ""
+            if op.device_option:
+                del op.device_option.extra_info[:]
         self.assertEqual(operatorDefList1, operatorDefList2)
 
     @given(device_option=st.sampled_from([