Add test cases for unique_with_counts_v2
authorYong Tang <yong.tang.github@outlook.com>
Sat, 27 Jan 2018 19:58:54 +0000 (19:58 +0000)
committerYong Tang <yong.tang.github@outlook.com>
Sat, 24 Feb 2018 03:20:45 +0000 (03:20 +0000)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/python/kernel_tests/unique_op_test.py

index 6366d2e..4498fd9 100644 (file)
@@ -133,6 +133,39 @@ class UniqueWithCountsTest(test.TestCase):
       v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)]
       self.assertEqual(count, sum(v))
 
+  def testInt32Axis(self):
+    for dtype in [np.int32, np.int64]:
+      x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
+      with self.test_session() as sess:
+        y0, idx0, count0 = gen_array_ops._unique_with_counts_v2(
+            x, axis=np.array([0], dtype))
+        tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0])
+        y1, idx1, count1 = gen_array_ops._unique_with_counts_v2(
+            x, axis=np.array([1], dtype))
+        tf_y1, tf_idx1, tf_count1 = sess.run([y1, idx1, count1])
+      self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
+      self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
+      self.assertAllEqual(tf_count0, np.array([2, 1]))
+      self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
+      self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
+      self.assertAllEqual(tf_count1, np.array([1, 2]))
+
+  def testInt32V2(self):
+    # This test is only temporary, once V2 is used
+    # by default, the axis will be wrapped to allow `axis=None`.
+    x = np.random.randint(2, high=10, size=7000)
+    with self.test_session() as sess:
+      y, idx, count = gen_array_ops._unique_with_counts_v2(
+          x, axis=np.array([], np.int32))
+      tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+
+    self.assertEqual(len(x), len(tf_idx))
+    self.assertEqual(len(tf_y), len(np.unique(x)))
+    for i in range(len(x)):
+      self.assertEqual(x[i], tf_y[tf_idx[i]])
+    for value, count in zip(tf_y, tf_count):
+      self.assertEqual(count, np.sum(x == value))
+
 
 if __name__ == '__main__':
   test.main()