Extend reduce_ops test to integers.
authorJacques Pienaar <jpienaar@google.com>
Wed, 14 Mar 2018 17:39:41 +0000 (10:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 17:44:11 +0000 (10:44 -0700)
PiperOrigin-RevId: 189049525

tensorflow/compiler/tests/reduce_ops_test.py

index 965fdf6..2c084b0 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import numpy as np
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -30,8 +31,13 @@ from tensorflow.python.platform import googletest
 
 class ReduceOpsTest(XLATestCase):
 
-  def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
-                     rtol=1e-4, atol=1e-4):
+  def _testReduction(self,
+                     tf_reduce_fn,
+                     np_reduce_fn,
+                     dtype,
+                     test_inputs,
+                     rtol=1e-4,
+                     atol=1e-4):
     """Tests that the output of 'tf_reduce_fn' matches numpy's output."""
 
     for test_input in test_inputs:
@@ -41,16 +47,16 @@ class ReduceOpsTest(XLATestCase):
           index = array_ops.placeholder(dtypes.int32)
           out = tf_reduce_fn(a, index)
         result = sess.run(out, {a: test_input, index: [0]})
-        self.assertAllClose(result, np_reduce_fn(test_input, axis=0),
-                            rtol=rtol, atol=atol)
+        self.assertAllClose(
+            result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
 
         result = sess.run(out, {a: test_input, index: [1]})
-        self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
-                            rtol=rtol, atol=atol)
+        self.assertAllClose(
+            result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
 
         result = sess.run(out, {a: test_input, index: [-1]})
-        self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
-                            rtol=rtol, atol=atol)
+        self.assertAllClose(
+            result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
 
         with self.assertRaisesWithPredicateMatch(
             errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
@@ -60,7 +66,7 @@ class ReduceOpsTest(XLATestCase):
             errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
           sess.run(out, {a: test_input, index: [2]})
 
-  FLOAT_DATA = [
+  REAL_DATA = [
       np.zeros(shape=(2, 0)),
       np.zeros(shape=(0, 30)),
       np.arange(1, 7).reshape(2, 3),
@@ -74,7 +80,7 @@ class ReduceOpsTest(XLATestCase):
       np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3),
       np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3),
   ]
-  NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0]
+  NONEMPTY_REAL_DATA = [x for x in REAL_DATA if np.size(x) > 0]
   NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0]
   BOOL_DATA = [
       np.array([], dtype=np.bool).reshape(2, 0),
@@ -83,8 +89,7 @@ class ReduceOpsTest(XLATestCase):
   ]
 
   def testReduceSumF32(self):
-    self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
-                        self.FLOAT_DATA)
+    self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA)
 
   def testReduceSumC64(self):
     self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
@@ -92,7 +97,7 @@ class ReduceOpsTest(XLATestCase):
 
   def testReduceProdF32(self):
     self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
-                        self.FLOAT_DATA)
+                        self.REAL_DATA)
 
   def testReduceProdC64(self):
     self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
@@ -100,31 +105,44 @@ class ReduceOpsTest(XLATestCase):
 
   def testReduceMin(self):
 
-    def reference_min(inp, axis):
+    def reference_min(dtype, inp, axis):
       """Wrapper around np.amin that returns +infinity for an empty input."""
       if inp.shape[axis] == 0:
-        return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
+        if np.issubdtype(dtype, np.floating):
+          return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
+        return np.full(inp.shape[0:axis] + inp.shape[axis + 1:],
+                       np.iinfo(dtype).max)
       return np.amin(inp, axis)
 
-    self._testReduction(math_ops.reduce_min, reference_min, np.float32,
-                        self.FLOAT_DATA)
+    for dtype in set(self.all_types).intersection(
+        [np.float32, np.int32, np.int64]):
+      self._testReduction(math_ops.reduce_min,
+                          functools.partial(reference_min, dtype), dtype,
+                          self.REAL_DATA)
 
   def testReduceMax(self):
 
-    def reference_max(inp, axis):
+    def reference_max(dtype, inp, axis):
       """Wrapper around np.amax that returns -infinity for an empty input."""
       if inp.shape[axis] == 0:
-        return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf'))
+        if np.issubdtype(dtype, np.floating):
+          return np.full(inp.shape[0:axis] + inp.shape[axis + 1:],
+                         float('-inf'))
+        return np.full(inp.shape[0:axis] + inp.shape[axis + 1:],
+                       np.iinfo(dtype).min)
       return np.amax(inp, axis)
 
-    self._testReduction(math_ops.reduce_max, reference_max, np.float32,
-                        self.FLOAT_DATA)
+    for dtype in set(self.all_types).intersection(
+        [np.float32, np.int32, np.int64]):
+      self._testReduction(math_ops.reduce_max,
+                          functools.partial(reference_max, dtype), dtype,
+                          self.REAL_DATA)
 
   def testReduceMeanF32(self):
     # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
     # reducing across zero inputs.
     self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
-                        self.NONEMPTY_FLOAT_DATA)
+                        self.NONEMPTY_REAL_DATA)
 
   def testReduceMeanC64(self):
     self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,