ROCm: Add warp shuffles and enable reductions (#5727)
authorThomas Viehmann <tv.code@beamnet.de>
Fri, 5 Jun 2020 09:49:37 +0000 (11:49 +0200)
committerGitHub <noreply@github.com>
Fri, 5 Jun 2020 09:49:37 +0000 (18:49 +0900)
Thank you @masahi and @wpan11nv for the feedback

src/target/llvm/intrin_rule_rocm.cc
src/target/target.cc
src/tir/transforms/lower_thread_allreduce.cc
tests/python/integration/test_reduce.py
tests/python/unittest/test_target_codegen_cuda.py
tests/python/unittest/test_target_codegen_rocm.py
tests/python/unittest/test_tir_transform_lower_warp_memory.py
topi/python/topi/cuda/softmax.py

index 52447a1..22af9f1 100644 (file)
@@ -24,6 +24,7 @@
 
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 
 #include <sstream>
 
@@ -40,8 +41,59 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
   *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
 }
 
+inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
+  PrimExpr e_call = targs[0];
+  using namespace tir;
+  const CallNode* call = e_call.as<CallNode>();
+  CHECK(call != nullptr);
+  CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  PrimExpr var = call->args[1];
+  CHECK_EQ(var.dtype().bits(), 32);
+
+  // get own lane in self (__lane_id)
+  PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
+  PrimExpr zero = tir::make_zero(DataType::Int(32));
+  PrimExpr lo = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero},
+                               CallNode::PureExtern);
+  PrimExpr self = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo},
+                                 CallNode::PureExtern);
+
+  // compute lane to get from
+  PrimExpr width = call->args[3];
+  PrimExpr index;
+  if (call->name == "tvm_warp_shuffle") {
+    PrimExpr src_lane = call->args[2];
+    index = src_lane + (self & ~(width - 1));
+  } else if (call->name == "tvm_warp_shuffle_up") {
+    PrimExpr delta = call->args[2];
+    index = self - delta;
+    index = SelectNode::make(index < (self & ~(width - 1)), self, index);
+  } else {
+    CHECK_EQ(call->name, "tvm_warp_shuffle_down");
+    PrimExpr delta = call->args[2];
+    index = self + delta;
+    index = SelectNode::make((self & (width - 1)) + delta >= width, self, index);
+  }
+  PrimExpr res = CallNode::make(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var},
+                                CallNode::PureExtern);
+  *rv = res;
+}
+
 namespace llvm {
 
+// dummy because we don't have the activemask
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_activemask")
+    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+      PrimExpr zero = tir::make_zero(DataType::Int(32));
+      *rv = zero;
+    });
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle").set_body(DispatchShuffle);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(DispatchShuffle);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML);
index f3ade8d..2104c2e 100644 (file)
@@ -98,8 +98,9 @@ Target CreateTarget(const std::string& target_name, const std::vector<std::strin
     // For now assume rocm schedule for opencl
     if (target_name == "opencl") {
       t->device_type = kDLOpenCL;
-    } else {
+    } else {  // rocm
       t->device_type = kDLROCM;
+      t->thread_warp_size = 64;
     }
     t->keys_array.push_back(target_name);
     t->keys_array.push_back("gpu");
index 127b012..de86647 100644 (file)
@@ -196,7 +196,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     //
     // Allocate reduction vars v[i], i = 0..size-1
     //
-    // for offset from 16 to 1 by 2
+    // for offset from WARP_SIZE to 1 by 2
     //
     //   a    <- load(v[i])
     //   b    <- shuffle_down(load(v[i], offset))
@@ -244,7 +244,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
 
       // Emit reductions within a warp.
-      for (int offset = 16; offset > 0; offset /= 2) {
+      for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
         // Load reduction values, no synchronization needed.
         Array<PrimExpr> a, b;
         for (size_t i = 0; i < size; ++i) {
@@ -478,9 +478,20 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // the warp size.
   //
   // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
+  // Note: The ROCm backend will only have warp reductions for now.
+  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
   bool is_warp_reduction(const std::vector<DataType>& types) const {
     // Only cuda target supports warp reductions.
-    if (target_->target_name != "cuda") return false;
+    if ((target_->target_name != "cuda") && (target_->target_name != "rocm")) return false;
+
+    // rocm only supports 32 bit operands for shuffling at the moment
+    if ((target_->target_name == "rocm") &&
+        (std::any_of(types.begin(), types.end(), [](DataType ty) {
+          if (ty.is_vector()) return true;
+          return ty.bits() != 32;
+        }))) {
+      return false;
+    }
 
     // Supported types:
     // {u}int, {u}long, {u}long long, float, double, half/half2
index 7ac3496..c5d9d08 100644 (file)
@@ -65,6 +65,7 @@ def test_reduce_prims():
         check_device("vulkan")
         check_device("cuda")
         check_device("opencl")
+        check_device("rocm")
     test_prim(te.sum, np.sum)
     test_prim(tvm.te.min, np.amin)
     test_prim(tvm.te.max, np.amax)
@@ -179,7 +180,7 @@ def test_rfactor_threads():
     check_target("cuda")
     check_target("metal")
     check_target("opencl")
-
+    check_target("rocm")
 
 def test_rfactor_elemwise_threads():
     n = 1025
@@ -230,6 +231,7 @@ def test_rfactor_elemwise_threads():
     check_target("cuda")
     check_target("metal")
     check_target("opencl")
+    check_target("rocm")
 
 def test_argmax():
     def fcombine(x, y):
@@ -337,6 +339,7 @@ def test_rfactor_argmax():
 
     check_target("cuda")
     check_target("vulkan")
+    check_target("rocm")
 
 def test_warp_reduction1():
     nthx = 32
@@ -365,10 +368,10 @@ def test_warp_reduction1():
         s[B].bind(xi, thread_y)
         s[B].bind(xo, block_x)
 
-        print(tvm.lower(s, [A, B], simple_mode=True))
+        tvm.lower(s, [A, B], simple_mode=True)
 
         # validation
-        func = tvm.build(s, [A, B], "cuda", name="warp_reduction")
+        func = tvm.build(s, [A, B], device, name="warp_reduction")
         a_np = np.random.uniform(size=(m,n)).astype(A.dtype)
         b_np = np.zeros((m,), dtype=A.dtype)
         a = tvm.nd.array(a_np, ctx)
@@ -379,6 +382,8 @@ def test_warp_reduction1():
 
     check_target("cuda", m=32, n=256)
     check_target("cuda", m=10, n=20)
+    check_target("rocm", m=32, n=256)
+    check_target("rocm", m=10, n=20)
     # This is a bug in normal reduction.
     # check_target("cuda", m=10, n=37)
 
@@ -437,6 +442,7 @@ def test_warp_reduction2():
         tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3)
 
     check_target("cuda")
+    check_target("rocm")
 
 if __name__ == "__main__":
     test_rfactor_elemwise_threads()
index efc3b4b..bafa957 100644 (file)
@@ -228,86 +228,94 @@ def test_cuda_shuffle():
         tvm.testing.assert_allclose(ndc.asnumpy(), ref)
 
 def test_crossthread_reduction1():
-    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled..")
-        return
-
-    n = te.var("n")
-    m = te.var("m")
-    A = te.placeholder((n, m), name='A')
-    k = te.reduce_axis((0, m), "m")
-    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+    def check(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
+            return
+        n = te.var("n")
+        m = te.var("m")
+        A = te.placeholder((n, m), name='A')
+        k = te.reduce_axis((0, m), "m")
+        B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+
+        def sched(nthd):
+            s = te.create_schedule(B.op)
+            ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd)
+            s[B].bind(ko, te.thread_axis("threadIdx.x"))
+            s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
+            func = tvm.build(s, [A, B], device)
+            return func
+
+        def verify(nthd):
+            func = sched(nthd)
+            nn = 3
+            # checks three typical cases
+            vals = [nthd-1, nthd, nthd+1]
+            for kk in [x for x in vals]:
+                size = (nn, kk)
+                a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
+                b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
+                func(a, b)
+                tvm.testing.assert_allclose(b.asnumpy(), \
+                    np.sum(a.asnumpy(), axis=1), rtol=1e-3)
+
+        verify(16)
+        verify(32)
+        verify(64)
+
+    check("cuda")
+    check("rocm")
 
-    def sched(nthd):
-        s = te.create_schedule(B.op)
-        ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd)
-        s[B].bind(ko, te.thread_axis("threadIdx.x"))
-        s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
-        func = tvm.build(s, [A, B], "cuda")
-        return func
-
-    def verify(nthd):
-        func = sched(nthd)
-        nn = 3
-        # checks three typical cases
-        vals = [nthd-1, nthd, nthd+1]
-        for kk in [x for x in vals]:
-            size = (nn, kk)
-            ctx = tvm.context("cuda", 0)
-            a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
-            b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
-            func(a, b)
-            tvm.testing.assert_allclose(b.asnumpy(), \
-                np.sum(a.asnumpy(), axis=1), rtol=1e-3)
-
-    verify(16)
-    verify(32)
-    verify(64)
 
 def test_crossthread_reduction2():
-    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled..")
-        return
-
-    n = te.var("n")
-    k0 = te.var("k0")
-    k1 = te.var("k1")
-    A = te.placeholder((n, k0, k1), name='A')
-    k0 = te.reduce_axis((0, k0), "k0")
-    k1 = te.reduce_axis((0, k1), "k1")
-    B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")
+    def check(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
+            return
 
-    def sched(nthdx, nthdy):
-        s = te.create_schedule(B.op)
-        k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx)
-        k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy)
-        s[B].bind(k0o, te.thread_axis("threadIdx.x"))
-        s[B].bind(k1o, te.thread_axis("threadIdx.y"))
-        s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
-        func = tvm.build(s, [A, B], "cuda")
-        return func
-
-    def verify(nthdx, nthdy):
-        func = sched(nthdx, nthdy)
-        nn = 3
-        # checks three typical cases
-        vx = [nthdx-1, nthdx, nthdx+1]
-        vy = [nthdy-1, nthdy, nthdy+1]
-        for kk0, kk1 in [(x, y) for x in vx for y in vy]:
-            size = (nn, kk0, kk1)
-            ctx = tvm.context("cuda", 0)
-            a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
-            b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
-            func(a, b)
-            tvm.testing.assert_allclose(b.asnumpy(), \
-                np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3)
-
-    verify(16, 16)
-    verify(32, 32)
-    verify(16, 32)
-    verify(32, 16)
-
-def test_cuda_reducition_binding():
+        n = te.var("n")
+        k0 = te.var("k0")
+        k1 = te.var("k1")
+        A = te.placeholder((n, k0, k1), name='A')
+        k0 = te.reduce_axis((0, k0), "k0")
+        k1 = te.reduce_axis((0, k1), "k1")
+        B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B")
+
+        def sched(nthdx, nthdy):
+            s = te.create_schedule(B.op)
+            k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx)
+            k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy)
+            s[B].bind(k0o, te.thread_axis("threadIdx.x"))
+            s[B].bind(k1o, te.thread_axis("threadIdx.y"))
+            s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
+            func = tvm.build(s, [A, B], device)
+            return func
+
+        def verify(nthdx, nthdy):
+            func = sched(nthdx, nthdy)
+            nn = 3
+            # checks three typical cases
+            vx = [nthdx-1, nthdx, nthdx+1]
+            vy = [nthdy-1, nthdy, nthdy+1]
+            for kk0, kk1 in [(x, y) for x in vx for y in vy]:
+                size = (nn, kk0, kk1)
+                a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx)
+                b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
+                func(a, b)
+                tvm.testing.assert_allclose(b.asnumpy(), \
+                    np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3)
+
+        verify(16, 16)
+        verify(32, 32)
+        verify(16, 32)
+        verify(32, 16)
+
+    check("cuda")
+    check("rocm")
+
+def test_cuda_reduction_binding():
     if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
         print("skip because cuda is not enabled..")
         return
@@ -327,39 +335,43 @@ def test_cuda_reducition_binding():
     fcuda = tvm.build(s, [A, B], "cuda")
 
 def test_rfactor_predicates():
-    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled..")
-        return
+    def check(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
+            return
 
-    n = te.reduce_axis((0, 129), 'n')
-    A = te.placeholder((129,), name='A')
-    B = te.compute( (1, ), lambda b:
-                     te.sum(A[n],
-                             axis=n),
-                     name='B'
-    )
+        n = te.reduce_axis((0, 129), 'n')
+        A = te.placeholder((129,), name='A')
+        B = te.compute( (1, ), lambda b:
+                         te.sum(A[n],
+                                 axis=n),
+                         name='B'
+        )
 
-    s = te.create_schedule(B.op)
+        s = te.create_schedule(B.op)
 
-    _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8)
+        _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8)
 
-    BF = s.rfactor(B, ni, 0)
-    s[B].set_store_predicate(tx.var.equal(0))
+        BF = s.rfactor(B, ni, 0)
+        s[B].set_store_predicate(tx.var.equal(0))
 
-    s[B].bind(s[B].op.reduce_axis[0], tx)
-    s[B].bind(s[B].op.axis[0], bx)
+        s[B].bind(s[B].op.reduce_axis[0], tx)
+        s[B].bind(s[B].op.axis[0], bx)
 
-    s[BF].compute_at(s[B], s[B].op.axis[0])
+        s[BF].compute_at(s[B], s[B].op.axis[0])
 
-    _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2)
+        _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2)
 
-    BF2 = s.rfactor(BF, noi, 0)
+        BF2 = s.rfactor(BF, noi, 0)
 
-    s[BF].bind(s[BF].op.axis[0], tx)
-    s[BF2].compute_at(s[BF], s[BF].op.axis[1])
+        s[BF].bind(s[BF].op.axis[0], tx)
+        s[BF2].compute_at(s[BF], s[BF].op.axis[1])
 
-    fcuda = tvm.build(s, [A, B], "cuda")
+        fcuda = tvm.build(s, [A, B], device)
 
+    check("cuda")
+    check("rocm")
 
 @unittest.skipIf(not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"), "skip because cuda is not enabled..")
 def test_cuda_const_float_to_half():
@@ -387,11 +399,12 @@ def test_cuda_const_float_to_half():
     np.testing.assert_equal(c.asnumpy(), a_np > b.value)
 
 def test_cuda_reduction():
-    def check_cuda(dtype, m=32, n=32):
-        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-            print("skip because cuda is not enabled..")
+    def check(device, dtype, m=32, n=32):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
             return
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(ctx.compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
@@ -401,10 +414,9 @@ def test_cuda_reduction():
         d = a * b
         e = topi.elemwise_sum([c, d])
         g = topi.sum(e)
-        with tvm.target.cuda():
+        with tvm.target.create(device):
             sg = topi.cuda.schedule_reduce(g)
-            ctx = tvm.gpu(0)
-            func = tvm.build(sg, [a, b, g], 'cuda')
+            func = tvm.build(sg, [a, b, g], device)
             a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
             b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
             g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
@@ -414,26 +426,27 @@ def test_cuda_reduction():
             func(a_nd, b_nd, g_nd)
             tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3)
 
-    check_cuda("float32")
-    check_cuda("float16")
+    check("cuda", "float32")
+    check("rocm", "float32")
+    check("cuda", "float16")
 
 def test_cuda_mix_threaded_and_normal_reduction():
-    def check_cuda(dtype, m=32, n=32):
-        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-            print("skip because cuda is not enabled..")
+    def check(device, dtype, m=32, n=32):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
             return
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(ctx.compute_version):
             print("Skip because gpu does not have fp16 support")
             return
 
         a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
         b = topi.sum(a)
-        with tvm.target.cuda():
+        with tvm.target.create(device):
             sb = tvm.te.create_schedule(b.op)
             i, _ = b.op.reduce_axis
             sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
-            ctx = tvm.gpu(0)
-            func = tvm.build(sb, [a, b], 'cuda')
+            func = tvm.build(sb, [a, b], device)
             a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
             b_np = np.sum(a_np)
             a_nd = tvm.nd.array(a_np, ctx)
@@ -441,8 +454,9 @@ def test_cuda_mix_threaded_and_normal_reduction():
             func(a_nd, b_nd)
             tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
 
-    check_cuda("float32")
-    check_cuda("float16")
+    check("cuda", "float32")
+    check("rocm", "float32")
+    check("cuda", "float16")
 
 def test_cuda_floordiv_with_vectorization():
     if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
@@ -687,7 +701,7 @@ if __name__ == "__main__":
     test_cuda_inf_nan()
     test_cuda_shuffle()
     test_vectorized_casts()
-    test_cuda_reducition_binding()
+    test_cuda_reduction_binding()
     test_crossthread_reduction1()
     test_crossthread_reduction2()
     test_rfactor_predicates()
index f107e59..4c6304a 100644 (file)
@@ -76,7 +76,7 @@ def test_rocm_inf_nan():
     check_inf_nan(ctx, 1, float('nan'), 'float64')
 
 @unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..")
-def test_rocm_reducition_binding():
+def test_rocm_reduction_binding():
     k = te.reduce_axis((0, 32), 'k')
     A = te.placeholder((96, 32), name='A')
     B = te.compute( (96,), lambda m:
@@ -132,6 +132,6 @@ def test_rocm_vectorize_add():
 if __name__ == "__main__":
     test_rocm_cross_thread_reduction()
     test_rocm_inf_nan()
-    test_rocm_reducition_binding()
+    test_rocm_reduction_binding()
     test_rocm_copy()
     test_rocm_vectorize_add()
index ce9dd56..51da6ea 100644 (file)
@@ -219,15 +219,11 @@ def test_lower_warp_memory_cuda_2_buffers():
     check_cuda("float16")
 
 def test_lower_warp_memory_roundup():
-    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled..")
-        return
-
-    def check(m):
+    def check(device, m):
         A = te.placeholder((m,), name='A')
         B = te.compute((m,), lambda i: A[i] + 1, name='B')
 
-        with tvm.target.create("cuda"):
+        with tvm.target.create(device):
             s = te.create_schedule(B.op)
             xo, xi = s[B].split(B.op.axis[0], factor=32)
             tx = te.thread_axis("threadIdx.x")
@@ -239,8 +235,8 @@ def test_lower_warp_memory_roundup():
             s[AA].bind(yi, tx)
             s[AA].compute_at(s[B], xo)
 
-            ctx = tvm.gpu(0)
-            func = tvm.build(s, [A, B], "cuda")
+            ctx = tvm.context(device, 0)
+            func = tvm.build(s, [A, B], device)
             A_np = np.random.uniform(size=(m,)).astype(A.dtype)
             B_np = np.zeros(shape=(m,)).astype(B.dtype)
             A_nd = tvm.nd.array(A_np, ctx)
@@ -249,9 +245,16 @@ def test_lower_warp_memory_roundup():
             B_np = A_np + 1
             tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)
 
-    check(m=31)
-    check(m=32)
-    check(m=33)
+    for device in ['cuda', 'rocm']:
+        if not tvm.context(device, 0).exist or not tvm.runtime.enabled(device):
+            print("skip because", device,"is not enabled..")
+            continue
+        check(device, m=31)
+        check(device, m=32)
+        check(device, m=33)
+        check(device, m=63)
+        check(device, m=64)
+        check(device, m=65)
 
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
index 50e2b0d..5f7402b 100644 (file)
@@ -54,11 +54,12 @@ def schedule_softmax(outs):
         raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
                          Got {0}'.format(op_tag))
 
-    # The nvptx backend only supports 32-bits warp shuffle instructions.
+    # The nvptx and rocm backends only supports 32-bits warp shuffle
+    # instructions.
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.target_name == "nvptx":
+        if tgt.target_name == "nvptx" or tgt.target_name == "rocm":
             return softmax.dtype == "float32" or softmax.dtype == "int32"
         if tgt.target_name != "cuda":
             # this is used as the gpu schedule for other arches which may not have warp reductions