From 1614536c1312eeba8356bb1e24a8bb51ce7e16c7 Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Wed, 16 May 2018 13:22:53 -0700 Subject: [PATCH] Fix the gradient of reduce_prod for complex dtypes. Fixes #12514 PiperOrigin-RevId: 196878148 --- tensorflow/python/ops/math_grad.py | 4 +++- tensorflow/python/ops/math_grad_test.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 02e07dc..563c0b3 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -171,7 +171,9 @@ def _ProdGrad(op, grad): # Calculate product, leaving out the current entry left = math_ops.cumprod(reshaped, axis=0, exclusive=True) right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True) - y = array_ops.reshape(left * right, permuted_shape) + # For complex inputs, the gradient is in the conjugate direction. + y = array_ops.reshape(math_ops.conj(left) * math_ops.conj(right), + permuted_shape) # Invert the transpose and reshape operations. # Make sure to set the statically known shape information through a reshape. diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index 04eeb00..fa47b8f 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -152,6 +152,28 @@ class ProdGradientTest(test.TestCase): outputs, outputs.get_shape().as_list()) self.assertLess(error, 1e-4) + def testProdGradientComplex(self): + for dtype in dtypes.complex64, dtypes.complex128: + inputs = constant_op.constant([[1 + 3j, 2 - 1j], [3j, 4]], + dtype=dtype) + outputs = math_ops.reduce_prod(inputs) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testProdGradientForNegativeAxisComplex(self): + for dtype in dtypes.complex64, dtypes.complex128: + inputs = constant_op.constant([[1 + 3j, 2 - 1j], [3j, 4]], + dtype=dtype) + outputs = math_ops.reduce_prod(inputs, -1) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + class SegmentMinOrMaxGradientTest(test.TestCase): -- 2.7.4