return input_name + '_grad'
+ IS_AUTO_GEN_SUM_OPS_TAG = "is_auto_gen_sum_ops"
+
def _SetSumOpsDeviceOption(self, sum_ops, generators):
# we already checked that device options are consistent so we can just
# use the first one we find
if grad_op.HasField('device_option'):
for op in sum_ops:
op.device_option.CopyFrom(grad_op.device_option)
- del op.device_option.extra_info[:]
+ op.device_option.extra_info.extend([
+ "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)
+ ])
break
def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
def assertOperatorListEqual(self, operatorDefList1, operatorDefList2):
for op in operatorDefList1:
op.debug_info = ""
+ if op.device_option:
+ del op.device_option.extra_info[:]
for op in operatorDefList2:
op.debug_info = ""
+ if op.device_option:
+ del op.device_option.extra_info[:]
self.assertEqual(operatorDefList1, operatorDefList2)
@given(device_option=st.sampled_from([