with open(grad_path, 'rb') as f:
loaded_grad = f.read()
grad_proto = parse_proto(loaded_grad)
- self.assertTrue(grad_proto == grad_ops[i])
+ self._assertSameOps(grad_proto, grad_ops[i])
shutil.rmtree(temp_dir)
+ def _assertSameOps(self, op1, op2):
+ op1_ = caffe2_pb2.OperatorDef()
+ op1_.CopyFrom(op1)
+ op1_.arg.sort(key=lambda arg: arg.name)
+
+ op2_ = caffe2_pb2.OperatorDef()
+ op2_.CopyFrom(op2)
+ op2_.arg.sort(key=lambda arg: arg.name)
+
+ self.assertEqual(op1_, op2_)
+
def assertSerializedOperatorChecks(
self,
inputs,