From d74b11ce0ed33a13bb6befff03c83a3d115255b1 Mon Sep 17 00:00:00 2001
From: Xianjie Chen
Date: Wed, 27 Mar 2019 14:52:13 -0700
Subject: [PATCH] add extra info for the auto gen sum ops

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17934

Reviewed By: iroot900

Differential Revision: D14418689

fbshipit-source-id: 9e11e461001467f0000ea7c355d5b0f0d738fa85
---
 caffe2/python/core.py                | 6 +++++-
 caffe2/python/core_gradients_test.py | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index 05e1226..2e9161a 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -714,6 +714,8 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
 
         return input_name + '_grad'
 
+    IS_AUTO_GEN_SUM_OPS_TAG = "is_auto_gen_sum_ops"
+
     def _SetSumOpsDeviceOption(self, sum_ops, generators):
         # we already checked that device options are consistent so we can just
         # use the first one we find
@@ -724,7 +726,9 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
             if grad_op.HasField('device_option'):
                 for op in sum_ops:
                     op.device_option.CopyFrom(grad_op.device_option)
-                    del op.device_option.extra_info[:]
+                    op.device_option.extra_info.extend([
+                        "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)
+                    ])
                 break
 
     def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
diff --git a/caffe2/python/core_gradients_test.py b/caffe2/python/core_gradients_test.py
index 75c2689..b6e9817 100644
--- a/caffe2/python/core_gradients_test.py
+++ b/caffe2/python/core_gradients_test.py
@@ -87,8 +87,12 @@ class TestGradientCalculation(test_util.TestCase):
     def assertOperatorListEqual(self, operatorDefList1, operatorDefList2):
         for op in operatorDefList1:
             op.debug_info = ""
+            if op.device_option:
+                del op.device_option.extra_info[:]
         for op in operatorDefList2:
             op.debug_info = ""
+            if op.device_option:
+                del op.device_option.extra_info[:]
         self.assertEqual(operatorDefList1, operatorDefList2)
 
     @given(device_option=st.sampled_from([
--
2.7.4
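
Note: a minimal, hypothetical usage sketch (not part of the patch) of how downstream code could pick out the auto-generated gradient-accumulation Sum ops via the new tag. It assumes the caffe2 Python API as of this change; the helper name find_auto_gen_sum_ops is illustrative only, and the tag is attached only when the gradient generator had a device_option to copy from.

from caffe2.python import core

def find_auto_gen_sum_ops(net):
    # Illustrative helper, not an API introduced by this patch.
    # The patch stores "is_auto_gen_sum_ops:1" in device_option.extra_info
    # of the Sum ops that AddGradientOperators generates to accumulate
    # gradients, so they can be told apart from user-authored Sum ops.
    tag = "{}:1".format(core.IR.IS_AUTO_GEN_SUM_OPS_TAG)
    return [
        op for op in net.Proto().op
        if op.type == "Sum" and tag in op.device_option.extra_info
    ]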