Expand tests to include int64 output type.
authorJacques Pienaar <jpienaar@google.com>
Wed, 16 May 2018 19:15:37 +0000 (12:15 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 19:18:15 +0000 (12:18 -0700)
PiperOrigin-RevId: 196868485

tensorflow/compiler/tests/argminmax_test.py

index ec547e1..ab30aad 100644 (file)
@@ -29,51 +29,70 @@ from tensorflow.python.platform import test
 
 class ArgMinMaxTest(xla_test.XLATestCase):
 
-  def _assertOpOutputMatchesExpected(self, op, inp, expected):
-    """Verifies that 'op' produces 'expected' when fed input 'inp' .
+  def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input,
+                                     expected):
+    """Verifies that 'op' produces 'expected' when fed input 'op_input' .
 
     Args:
-      op: operator to test
-      inp: numpy input array to use as input to 'op'.
+      op: argmin or argmax operator to test.
+      axis: integer axis to reduce across.
+      output_type: numpy datatype of the output to produce.
+      op_input: numpy input array to use as input to 'op'.
       expected: numpy array representing the expected output of 'op'.
     """
     with self.test_session() as session:
       with self.test_scope():
         pinp = array_ops.placeholder(
-            dtypes.as_dtype(inp.dtype), inp.shape, name="a")
-        output = op(pinp)
-      result = session.run(output, {pinp: inp})
+            dtypes.as_dtype(op_input.dtype), op_input.shape, name="a")
+        output = op(pinp, axis=axis, output_type=output_type)
+      result = session.run(output, {pinp: op_input})
       self.assertAllEqual(result, expected)
 
   def testArgMinMax(self):
     # Complex numbers do not support argmin/argmax.
     minmax_types = set(self.numeric_types) - set(self.complex_types)
-    for dtype in minmax_types:
-      self._assertOpOutputMatchesExpected(
-          lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
-          np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
-          expected=np.int32(2))
-      self._assertOpOutputMatchesExpected(
-          lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
-          np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
-          expected=np.array([0, 1, 0], dtype=np.int32))
-      self._assertOpOutputMatchesExpected(
-          lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32),
-          np.array([[4, 1], [3, 2]], dtype=dtype),
-          expected=np.array([0, 0], dtype=np.int32))
+    for dtype in sorted(minmax_types):
+      # output_type is a numpy data type that is used to specify the desired
+      # output type of the op as well as to convert the Python number to the
+      # array scalar of the type.
+      for output_type in self.int_types:
+        self._assertOpOutputMatchesExpected(
+            math_ops.argmax,
+            axis=0,
+            output_type=output_type,
+            op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
+            expected=output_type(2))
+        self._assertOpOutputMatchesExpected(
+            math_ops.argmax,
+            axis=0,
+            output_type=output_type,
+            op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
+            expected=np.array([0, 1, 0], dtype=output_type))
+        self._assertOpOutputMatchesExpected(
+            math_ops.argmax,
+            axis=1,
+            output_type=output_type,
+            op_input=np.array([[4, 1], [3, 2]], dtype=dtype),
+            expected=np.array([0, 0], dtype=output_type))
 
-      self._assertOpOutputMatchesExpected(
-          lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32),
-          np.array([3, 10, 27, 3, 2, 4], dtype=dtype),
-          expected=np.int32(4))
-      self._assertOpOutputMatchesExpected(
-          lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32),
-          np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
-          expected=np.array([1, 0, 1], dtype=np.int32))
-      self._assertOpOutputMatchesExpected(
-          lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32),
-          np.array([[4, 1], [3, 2]], dtype=dtype),
-          expected=np.array([1, 1], dtype=np.int32))
+        self._assertOpOutputMatchesExpected(
+            math_ops.argmin,
+            axis=0,
+            output_type=output_type,
+            op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype),
+            expected=output_type(4))
+        self._assertOpOutputMatchesExpected(
+            math_ops.argmin,
+            axis=0,
+            output_type=output_type,
+            op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
+            expected=np.array([1, 0, 1], dtype=output_type))
+        self._assertOpOutputMatchesExpected(
+            math_ops.argmin,
+            axis=1,
+            output_type=output_type,
+            op_input=np.array([[4, 1], [3, 2]], dtype=dtype),
+            expected=np.array([1, 1], dtype=output_type))
 
 
 if __name__ == "__main__":