Add AMD codeGen unit tests (#4509)
authorPeter Yeh <petrex@users.noreply.github.com>
Thu, 12 Dec 2019 01:19:42 +0000 (17:19 -0800)
committermasahi <masahi129@gmail.com>
Thu, 12 Dec 2019 01:19:42 +0000 (10:19 +0900)
tests/python/unittest/test_codegen_rocm.py

index 2077372..bba72e0 100644 (file)
 # under the License.
 import tvm
 import numpy as np
+import unittest
 
+tx = tvm.thread_axis("threadIdx.x")
+ty = tvm.thread_axis("threadIdx.y")
+bx = tvm.thread_axis("blockIdx.x")
+by = tvm.thread_axis("blockIdx.y")
 
+@unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..")
 def test_rocm_cross_thread_reduction():
-    if not tvm.rocm(0).exist or not tvm.module.enabled("rocm"):
-        print("skip because rocm is not enabled..")
-        return
-
     # based on the reduction tutorial
     n = tvm.var("n")
     m = tvm.var("m")
@@ -33,9 +35,8 @@ def test_rocm_cross_thread_reduction():
     ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
     BF = s.rfactor(B, ki)
     xo, xi = s[B].split(s[B].op.axis[0], factor=32)
-    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
-    s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
-    tx = tvm.thread_axis("threadIdx.x")
+    s[B].bind(xo, bx) 
+    s[B].bind(xi, ty)
     s[B].bind(s[B].op.reduce_axis[0], tx)
     s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
     s[B].set_store_predicate(tx.var.equal(0))
@@ -49,6 +50,87 @@ def test_rocm_cross_thread_reduction():
     tvm.testing.assert_allclose(
       b.asnumpy(),  np.sum(a.asnumpy(), axis=1), rtol=1e-4)
 
+    
+@unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..")
+def test_rocm_inf_nan():
+    def check_inf_nan(ctx, n, value, dtype):
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        inf_value = tvm.const(value, dtype=dtype)
+        C = tvm.compute((n,), lambda i: inf_value, name='C')
+        s = tvm.create_schedule(C.op)
+        s[C].bind(s[C].op.axis[0], tx)
+        fun = tvm.build(s, [A, C], "rocm")
+        a = tvm.nd.empty((n,), A.dtype, ctx)
+        c = tvm.nd.empty((n,), A.dtype, ctx)
+        # Only need to test compiling here
+        fun(a, c)
+
+    ctx = tvm.rocm(0)
+
+    check_inf_nan(ctx, 1, -float('inf'), 'float32')
+    check_inf_nan(ctx, 1, -float('inf'), 'float64')
+    check_inf_nan(ctx, 1, float('inf'), 'float32')
+    check_inf_nan(ctx, 1, float('inf'), 'float64')
+    check_inf_nan(ctx, 1, float('nan'), 'float32')
+    check_inf_nan(ctx, 1, float('nan'), 'float64')
+
+@unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..")
+def test_rocm_reducition_binding():
+    k = tvm.reduce_axis((0, 32), 'k')
+    A = tvm.placeholder((96, 32), name='A')
+    B = tvm.compute( (96,), lambda m:
+                     tvm.sum(A[m, k], axis=k),
+                     name='B')
+    s = tvm.create_schedule(B.op)
+
+    s[B].reorder(B.op.reduce_axis[0], B.op.axis[0])
+
+    mo, _ = s[B].split(B.op.axis[0], 32)
+    s[B].bind(mo, bx)
+
+@unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..")
+def test_rocm_copy():
+
+    def check_rocm(dtype, n):
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        ctx = tvm.rocm(0)
+        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
+        b_np = a.asnumpy()
+        tvm.testing.assert_allclose(a_np, b_np)
+        tvm.testing.assert_allclose(a_np, a.asnumpy())
+
+    for _ in range(100):
+        dtype = np.random.choice(["float32", "float16", "int8", "int32"])
+        logN = np.random.randint(1, 15)
+        peturb = np.random.uniform(low=0.5, high=1.5)
+        check_rocm(dtype, int(peturb * (2 ** logN)))
+
+@unittest.skipIf(not tvm.rocm(0).exist or not tvm.module.enabled("rocm"), "skip because rocm is not enabled..")
+def test_rocm_vectorize_add():
+    num_thread = 8
+
+    def check_rocm(dtype, n, lanes):
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
+        s = tvm.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(xo, bx)
+        s[B].bind(xi, tx)
+        fun = tvm.build(s, [A, B], "rocm")
+        ctx = tvm.rocm(0)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
+            np.random.uniform(size=(n, lanes)))
+        c = tvm.nd.empty((n,), B.dtype, ctx)
+        fun(a, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
+
+    check_rocm("float32", 64, 2)
+    check_rocm("float16", 64, 2)
 
 if __name__ == "__main__":
     test_rocm_cross_thread_reduction()
+    test_rocm_inf_nan()
+    test_rocm_reducition_binding()
+    test_rocm_copy()
+    test_rocm_vectorize_add()