From: Thomas Viehmann
Date: Fri, 5 Jun 2020 09:49:37 +0000 (+0200)
Subject: ROCm: Add warp shuffles and enable reductions (#5727)
X-Git-Tag: upstream/0.7.0~607
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e1b11712ac09c32614483d24a4c7e0245ee4cb4b;p=platform%2Fupstream%2Ftvm.git

ROCm: Add warp shuffles and enable reductions (#5727)

Thank you @masahi and @wpan11nv for the feedback
---

diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc
index 52447a1..22af9f1 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -24,6 +24,7 @@
 #include
 #include
+#include
 #include
@@ -40,8 +41,59 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
   *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
 }

+inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
+  PrimExpr e_call = targs[0];
+  using namespace tir;
+  const CallNode* call = e_call.as();
+  CHECK(call != nullptr);
+  CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  PrimExpr var = call->args[1];
+  CHECK_EQ(var.dtype().bits(), 32);
+
+  // get own lane in self (__lane_id)
+  PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
+  PrimExpr zero = tir::make_zero(DataType::Int(32));
+  PrimExpr lo = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero},
+                               CallNode::PureExtern);
+  PrimExpr self = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo},
+                                 CallNode::PureExtern);
+
+  // compute lane to get from
+  PrimExpr width = call->args[3];
+  PrimExpr index;
+  if (call->name == "tvm_warp_shuffle") {
+    PrimExpr src_lane = call->args[2];
+    index = src_lane + (self & ~(width - 1));
+  } else if (call->name == "tvm_warp_shuffle_up") {
+    PrimExpr delta = call->args[2];
+    index = self - delta;
+    index = SelectNode::make(index < (self & ~(width - 1)), self, index);
+  } else {
+    CHECK_EQ(call->name, "tvm_warp_shuffle_down");
+    PrimExpr delta = call->args[2];
+    index = self + delta;
+    index = SelectNode::make((self & (width - 1)) + delta >= width, self, index);
+  }
+  PrimExpr res = CallNode::make(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var},
+                                CallNode::PureExtern);
+  *rv = res;
+}
+
 namespace llvm {

+// dummy because we don't have the activemask
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_activemask")
+    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+      PrimExpr zero = tir::make_zero(DataType::Int(32));
+      *rv = zero;
+    });
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle").set_body(DispatchShuffle);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(DispatchShuffle);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML);

diff --git a/src/target/target.cc b/src/target/target.cc
index f3ade8d..2104c2e 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -98,8 +98,9 @@ Target CreateTarget(const std::string& target_name, const std::vector
       t->device_type = kDLOpenCL;
-    } else {
+    } else {  // rocm
       t->device_type = kDLROCM;
+      t->thread_warp_size = 64;
     }
     t->keys_array.push_back(target_name);
     t->keys_array.push_back("gpu");
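For readers who do not know the AMDGCN intrinsics used in DispatchShuffle above: llvm.amdgcn.mbcnt.lo/hi with an all-ones mask is the usual way to obtain the lane id, the source lane is then derived with the Select expressions in the patch, and llvm.amdgcn.ds.bpermute gathers across lanes using byte addresses, which is why the index is shifted left by 2 for 32-bit operands. The following is a minimal, illustrative Python model of just that index arithmetic; the names (shuffle_index, ds_bpermute, WARP_SIZE) are ours, not TVM's, and the real lowering emits the LLVM intrinsics directly.

# Illustrative sketch of the lane-index arithmetic emitted above (not TVM code).
WARP_SIZE = 64  # ROCm wavefront size; CUDA warps use 32

def shuffle_index(kind, self_lane, arg, width):
    """Lane that `self_lane` reads from, mirroring the Select logic in DispatchShuffle."""
    if kind == "shuffle":        # read absolute lane `arg` within the sub-group
        return arg + (self_lane & ~(width - 1))
    if kind == "shuffle_up":     # read `arg` lanes below, clamped at the group start
        index = self_lane - arg
        return self_lane if index < (self_lane & ~(width - 1)) else index
    if kind == "shuffle_down":   # read `arg` lanes above, clamped at the group end
        index = self_lane + arg
        return self_lane if (self_lane & (width - 1)) + arg >= width else index
    raise ValueError(kind)

def ds_bpermute(values, byte_indices):
    # ds_bpermute is byte-addressed, hence the `index << 2` for 32-bit values;
    # modelled here as a plain gather after dividing the byte offset back out.
    return [values[(i >> 2) % WARP_SIZE] for i in byte_indices]

lanes = list(range(WARP_SIZE))
idx = [shuffle_index("shuffle_down", lane, 4, WARP_SIZE) << 2 for lane in lanes]
print(ds_bpermute(lanes, idx)[:8])  # lanes 0..7 read the values of lanes 4..11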
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 127b012..de86647 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -196,7 +196,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     //
     // Allocate reduction vars v[i], i = 0..size-1
     //
-    // for offset from 16 to 1 by 2
+    // for offset from WARP_SIZE to 1 by 2
     //
     //   a <- load(v[i])
     //   b <- shuffle_down(load(v[i], offset))
@@ -244,7 +244,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }

       // Emit reductions within a warp.
-      for (int offset = 16; offset > 0; offset /= 2) {
+      for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
         // Load reduction values, no synchronization needed.
         Array a, b;
         for (size_t i = 0; i < size; ++i) {
@@ -478,9 +478,20 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // the warp size.
   //
   // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
+  // Note: The ROCm backend will only have warp reductions for now.
+  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
   bool is_warp_reduction(const std::vector& types) const {
     // Only cuda target supports warp reductions.
-    if (target_->target_name != "cuda") return false;
+    if ((target_->target_name != "cuda") && (target_->target_name != "rocm")) return false;
+
+    // rocm only supports 32 bit operands for shuffling at the moment
+    if ((target_->target_name == "rocm") &&
+        (std::any_of(types.begin(), types.end(), [](DataType ty) {
+          if (ty.is_vector()) return true;
+          return ty.bits() != 32;
+        }))) {
+      return false;
+    }

     // Supported types:
     // {u}int, {u}long, {u}long long, float, double, half/half2
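The reduction loop above now starts at warp_size_ / 2 instead of a hard-coded 16, so the same shuffle-down tree serves both a 32-lane CUDA warp and a 64-lane ROCm wavefront. A minimal Python model of that tree, assuming all lanes are active (warp_reduce and its arguments are illustrative names, not TVM code):

# Illustrative model of the shuffle_down reduction tree emitted by the pass.
def warp_reduce(values, warp_size=64):  # 64 on rocm, 32 on cuda
    vals = list(values)
    offset = warp_size // 2
    while offset > 0:
        nxt = []
        for lane in range(warp_size):
            # shuffle_down: each lane reads the pre-update value `offset` lanes above
            src = lane + offset if lane + offset < warp_size else lane
            nxt.append(vals[lane] + vals[src])
        vals = nxt
        offset //= 2
    # only lane 0 is guaranteed to hold the full reduction; the pass stores that lane
    return vals[0]

print(warp_reduce(range(64)))  # 2016 == sum(range(64))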
# check_target("cuda", m=10, n=37) @@ -437,6 +442,7 @@ def test_warp_reduction2(): tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3) check_target("cuda") + check_target("rocm") if __name__ == "__main__": test_rfactor_elemwise_threads() diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index efc3b4b..bafa957 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -228,86 +228,94 @@ def test_cuda_shuffle(): tvm.testing.assert_allclose(ndc.asnumpy(), ref) def test_crossthread_reduction1(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - - n = te.var("n") - m = te.var("m") - A = te.placeholder((n, m), name='A') - k = te.reduce_axis((0, m), "m") - B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") + def check(device): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") + return + n = te.var("n") + m = te.var("m") + A = te.placeholder((n, m), name='A') + k = te.reduce_axis((0, m), "m") + B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") + + def sched(nthd): + s = te.create_schedule(B.op) + ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd) + s[B].bind(ko, te.thread_axis("threadIdx.x")) + s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) + func = tvm.build(s, [A, B], device) + return func + + def verify(nthd): + func = sched(nthd) + nn = 3 + # checks three typical cases + vals = [nthd-1, nthd, nthd+1] + for kk in [x for x in vals]: + size = (nn, kk) + a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), \ + np.sum(a.asnumpy(), axis=1), rtol=1e-3) + + verify(16) + verify(32) + verify(64) + + check("cuda") + check("rocm") - def sched(nthd): - s = te.create_schedule(B.op) - ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd) - s[B].bind(ko, te.thread_axis("threadIdx.x")) - s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) - func = tvm.build(s, [A, B], "cuda") - return func - - def verify(nthd): - func = sched(nthd) - nn = 3 - # checks three typical cases - vals = [nthd-1, nthd, nthd+1] - for kk in [x for x in vals]: - size = (nn, kk) - ctx = tvm.context("cuda", 0) - a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) - func(a, b) - tvm.testing.assert_allclose(b.asnumpy(), \ - np.sum(a.asnumpy(), axis=1), rtol=1e-3) - - verify(16) - verify(32) - verify(64) def test_crossthread_reduction2(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return - - n = te.var("n") - k0 = te.var("k0") - k1 = te.var("k1") - A = te.placeholder((n, k0, k1), name='A') - k0 = te.reduce_axis((0, k0), "k0") - k1 = te.reduce_axis((0, k1), "k1") - B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B") + def check(device): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") + return - def sched(nthdx, nthdy): - s = te.create_schedule(B.op) - k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx) - k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy) - s[B].bind(k0o, te.thread_axis("threadIdx.x")) - s[B].bind(k1o, 
te.thread_axis("threadIdx.y")) - s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) - func = tvm.build(s, [A, B], "cuda") - return func - - def verify(nthdx, nthdy): - func = sched(nthdx, nthdy) - nn = 3 - # checks three typical cases - vx = [nthdx-1, nthdx, nthdx+1] - vy = [nthdy-1, nthdy, nthdy+1] - for kk0, kk1 in [(x, y) for x in vx for y in vy]: - size = (nn, kk0, kk1) - ctx = tvm.context("cuda", 0) - a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) - func(a, b) - tvm.testing.assert_allclose(b.asnumpy(), \ - np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3) - - verify(16, 16) - verify(32, 32) - verify(16, 32) - verify(32, 16) - -def test_cuda_reducition_binding(): + n = te.var("n") + k0 = te.var("k0") + k1 = te.var("k1") + A = te.placeholder((n, k0, k1), name='A') + k0 = te.reduce_axis((0, k0), "k0") + k1 = te.reduce_axis((0, k1), "k1") + B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B") + + def sched(nthdx, nthdy): + s = te.create_schedule(B.op) + k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx) + k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy) + s[B].bind(k0o, te.thread_axis("threadIdx.x")) + s[B].bind(k1o, te.thread_axis("threadIdx.y")) + s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) + func = tvm.build(s, [A, B], device) + return func + + def verify(nthdx, nthdy): + func = sched(nthdx, nthdy) + nn = 3 + # checks three typical cases + vx = [nthdx-1, nthdx, nthdx+1] + vy = [nthdy-1, nthdy, nthdy+1] + for kk0, kk1 in [(x, y) for x in vx for y in vy]: + size = (nn, kk0, kk1) + a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), \ + np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3) + + verify(16, 16) + verify(32, 32) + verify(16, 32) + verify(32, 16) + + check("cuda") + check("rocm") + +def test_cuda_reduction_binding(): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled..") return @@ -327,39 +335,43 @@ def test_cuda_reducition_binding(): fcuda = tvm.build(s, [A, B], "cuda") def test_rfactor_predicates(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return + def check(device): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") + return - n = te.reduce_axis((0, 129), 'n') - A = te.placeholder((129,), name='A') - B = te.compute( (1, ), lambda b: - te.sum(A[n], - axis=n), - name='B' - ) + n = te.reduce_axis((0, 129), 'n') + A = te.placeholder((129,), name='A') + B = te.compute( (1, ), lambda b: + te.sum(A[n], + axis=n), + name='B' + ) - s = te.create_schedule(B.op) + s = te.create_schedule(B.op) - _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8) + _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8) - BF = s.rfactor(B, ni, 0) - s[B].set_store_predicate(tx.var.equal(0)) + BF = s.rfactor(B, ni, 0) + s[B].set_store_predicate(tx.var.equal(0)) - s[B].bind(s[B].op.reduce_axis[0], tx) - s[B].bind(s[B].op.axis[0], bx) + s[B].bind(s[B].op.reduce_axis[0], tx) + s[B].bind(s[B].op.axis[0], bx) - s[BF].compute_at(s[B], s[B].op.axis[0]) + s[BF].compute_at(s[B], s[B].op.axis[0]) - _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2) + _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2) - BF2 = s.rfactor(BF, noi, 0) + BF2 = s.rfactor(BF, noi, 0) - 
-    s[BF].bind(s[BF].op.axis[0], tx)
-    s[BF2].compute_at(s[BF], s[BF].op.axis[1])
+        s[BF].bind(s[BF].op.axis[0], tx)
+        s[BF2].compute_at(s[BF], s[BF].op.axis[1])

-    fcuda = tvm.build(s, [A, B], "cuda")
+        fcuda = tvm.build(s, [A, B], device)
+
+    check("cuda")
+    check("rocm")

 @unittest.skipIf(not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"), "skip because cuda is not enabled..")
 def test_cuda_const_float_to_half():
@@ -387,11 +399,12 @@ def test_cuda_const_float_to_half():
     np.testing.assert_equal(c.asnumpy(), a_np > b.value)

 def test_cuda_reduction():
-    def check_cuda(dtype, m=32, n=32):
-        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-            print("skip because cuda is not enabled..")
+    def check(device, dtype, m=32, n=32):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
             return
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(ctx.compute_version):
             print("Skip because gpu does not have fp16 support")
             return

@@ -401,10 +414,9 @@ def test_cuda_reduction():
         d = a * b
         e = topi.elemwise_sum([c, d])
         g = topi.sum(e)
-        with tvm.target.cuda():
+        with tvm.target.create(device):
             sg = topi.cuda.schedule_reduce(g)
-            ctx = tvm.gpu(0)
-            func = tvm.build(sg, [a, b, g], 'cuda')
+            func = tvm.build(sg, [a, b, g], device)
             a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
             b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
             g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
@@ -414,26 +426,27 @@ def test_cuda_reduction():
             func(a_nd, b_nd, g_nd)
             tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3)

-    check_cuda("float32")
-    check_cuda("float16")
+    check("cuda", "float32")
+    check("rocm", "float32")
+    check("cuda", "float16")

 def test_cuda_mix_threaded_and_normal_reduction():
-    def check_cuda(dtype, m=32, n=32):
-        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-            print("skip because cuda is not enabled..")
+    def check(device, dtype, m=32, n=32):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist or not tvm.runtime.enabled(device):
+            print("skip because", device, "is not enabled..")
             return
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+        if dtype == "float16" and not have_fp16(ctx.compute_version):
             print("Skip because gpu does not have fp16 support")
             return

         a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
         b = topi.sum(a)
-        with tvm.target.cuda():
+        with tvm.target.create(device):
             sb = tvm.te.create_schedule(b.op)
             i, _ = b.op.reduce_axis
             sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
-            ctx = tvm.gpu(0)
-            func = tvm.build(sb, [a, b], 'cuda')
+            func = tvm.build(sb, [a, b], device)
             a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
             b_np = np.sum(a_np)
             a_nd = tvm.nd.array(a_np, ctx)
@@ -441,8 +454,9 @@ def test_cuda_mix_threaded_and_normal_reduction():
             func(a_nd, b_nd)
             tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

-    check_cuda("float32")
-    check_cuda("float16")
+    check("cuda", "float32")
+    check("rocm", "float32")
+    check("cuda", "float16")

 def test_cuda_floordiv_with_vectorization():
     if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
@@ -687,7 +701,7 @@ if __name__ == "__main__":
     test_cuda_inf_nan()
     test_cuda_shuffle()
     test_vectorized_casts()
-    test_cuda_reducition_binding()
+    test_cuda_reduction_binding()
     test_crossthread_reduction1()
     test_crossthread_reduction2()
     test_rfactor_predicates()
diff --git a/tests/python/unittest/test_target_codegen_rocm.py b/tests/python/unittest/test_target_codegen_rocm.py
index f107e59..4c6304a 100644
--- a/tests/python/unittest/test_target_codegen_rocm.py
+++ b/tests/python/unittest/test_target_codegen_rocm.py
@@ -76,7 +76,7 @@ def test_rocm_inf_nan():
     check_inf_nan(ctx, 1, float('nan'), 'float64')

 @unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..")
-def test_rocm_reducition_binding():
+def test_rocm_reduction_binding():
     k = te.reduce_axis((0, 32), 'k')
     A = te.placeholder((96, 32), name='A')
     B = te.compute( (96,), lambda m:
@@ -132,6 +132,6 @@ def test_rocm_vectorize_add():
 if __name__ == "__main__":
     test_rocm_cross_thread_reduction()
     test_rocm_inf_nan()
-    test_rocm_reducition_binding()
+    test_rocm_reduction_binding()
     test_rocm_copy()
     test_rocm_vectorize_add()

diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index ce9dd56..51da6ea 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -219,15 +219,11 @@ def test_lower_warp_memory_cuda_2_buffers():
     check_cuda("float16")

 def test_lower_warp_memory_roundup():
-    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
-        print("skip because cuda is not enabled..")
-        return
-
-    def check(m):
+    def check(device, m):
         A = te.placeholder((m,), name='A')
         B = te.compute((m,), lambda i: A[i] + 1, name='B')

-        with tvm.target.create("cuda"):
+        with tvm.target.create(device):
             s = te.create_schedule(B.op)
             xo, xi = s[B].split(B.op.axis[0], factor=32)
             tx = te.thread_axis("threadIdx.x")
@@ -239,8 +235,8 @@ def test_lower_warp_memory_roundup():
             s[AA].bind(yi, tx)
             s[AA].compute_at(s[B], xo)

-        ctx = tvm.gpu(0)
-        func = tvm.build(s, [A, B], "cuda")
+        ctx = tvm.context(device, 0)
+        func = tvm.build(s, [A, B], device)
         A_np = np.random.uniform(size=(m,)).astype(A.dtype)
         B_np = np.zeros(shape=(m,)).astype(B.dtype)
         A_nd = tvm.nd.array(A_np, ctx)
@@ -249,9 +245,16 @@ def test_lower_warp_memory_roundup():
         B_np = A_np + 1
         tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)

-    check(m=31)
-    check(m=32)
-    check(m=33)
+    for device in ['cuda', 'rocm']:
+        if not tvm.context(device, 0).exist or not tvm.runtime.enabled(device):
+            print("skip because", device,"is not enabled..")
+            continue
+        check(device, m=31)
+        check(device, m=32)
+        check(device, m=33)
+        check(device, m=63)
+        check(device, m=64)
+        check(device, m=65)

 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()

diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 50e2b0d..5f7402b 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -54,11 +54,12 @@ def schedule_softmax(outs):
         raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
 Got {0}'.format(op_tag))

-    # The nvptx backend only supports 32-bits warp shuffle instructions.
+    # The nvptx and rocm backends only supports 32-bits warp shuffle
+    # instructions.
     #
     # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
     def sched_warp_softmax():
-        if tgt.target_name == "nvptx":
+        if tgt.target_name == "nvptx" or tgt.target_name == "rocm":
             return softmax.dtype == "float32" or softmax.dtype == "int32"
         if tgt.target_name != "cuda":
             # this is used as the gpu schedule for other arches which may not have warp reductions
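A note on the new sizes in test_lower_warp_memory_roundup: warp memory is padded up to a whole number of warps, so with ROCm's 64-lane wavefront the interesting boundary cases are m = 63, 64 and 65, mirroring 31/32/33 for CUDA's 32-lane warps. A tiny illustrative sketch of that rounding (our own helper, not the TVM pass):

# Illustrative only: rounding a per-warp buffer size up to the warp size.
def round_up_to_warp(m, warp_size):
    return ((m + warp_size - 1) // warp_size) * warp_size

for warp_size in (32, 64):
    for m in (warp_size - 1, warp_size, warp_size + 1):
        print(warp_size, m, round_up_to_warp(m, warp_size))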