self.assertEqual(z, dtype(3))
def testComplexDataTypes(self):
- def sum_func(x, y):
- return x + y
+ def sub_func(x, y):
+ return x - y
for dtype in [np.complex64, np.complex128]:
with self.test_session():
x = constant_op.constant(1 + 1j, dtype=dtype)
- y = constant_op.constant(2 + 2j, dtype=dtype)
- z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
- self.assertEqual(z, dtype(3 + 3j))
+ y = constant_op.constant(2 - 2j, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
+ self.assertEqual(z, dtype(-1 + 3j))
def testSingleType(self):
with self.test_session():