const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
- Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+ Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
const char* name = T()(call->dtype, call->name);
*rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
}
print("Skip because gpu does not have fp16 support")
return
- m = 16
- A = te.placeholder((m,), name='A', dtype=dtype)
- B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+ n, m = 16, 16
+ A = te.placeholder((n, m,), name='A', dtype=dtype)
+ B = te.compute((n, m,), lambda j, i: A[j, (i + 1) % m], name='B')
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 2 * m
with cuda_target:
s = te.create_schedule(B.op)
tx = te.thread_axis("threadIdx.x")
+ ty = te.thread_axis("threadIdx.y")
bx = te.thread_axis("blockIdx.x")
AA = s.cache_read(A, "warp", [B])
- xo, xi = s[B].split(B.op.axis[0], nparts=1)
- s[B].bind(xi, tx)
- s[B].bind(xo, bx)
- s[AA].compute_at(s[B], xo)
- xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
- s[AA].bind(xo, bx)
- s[AA].bind(xi, tx)
+ y, x = B.op.axis
+ z, y = s[B].split(y, nparts=2)
+ s[B].bind(x, tx)
+ s[B].bind(y, ty)
+ s[B].bind(z, bx)
+ s[AA].compute_at(s[B], y)
+ _, x = AA.op.axis
+ s[AA].bind(x, tx)
ctx = tvm.gpu(0)
func = tvm.build(s, [A, B], "cuda")
- A_np = np.array(list(range(m)), dtype=dtype)
- B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+ A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
+ B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
A_nd = tvm.nd.array(A_np, ctx)
B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
func(A_nd, B_nd)