From: Zijing Gu Date: Wed, 10 Jun 2020 17:07:36 +0000 (-0400) Subject: [topi] block sparse dense on cuda (#5746) X-Git-Tag: upstream/0.7.0~589 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ed583092dbeb4f1b0458ad015f607f0746d61e80;p=platform%2Fupstream%2Ftvm.git [topi] block sparse dense on cuda (#5746) --- diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index ba5c54b..78e3680 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -50,3 +50,4 @@ from .conv2d_nhwc_tensorcore import * from .conv3d_ndhwc_tensorcore import * from .dense_tensorcore import * from .correlation import * +from .sparse import * diff --git a/topi/python/topi/cuda/sparse.py b/topi/python/topi/cuda/sparse.py new file mode 100644 index 0000000..037eea4 --- /dev/null +++ b/topi/python/topi/cuda/sparse.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Sparse operators""" +from tvm import te +from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity +from ..util import traverse_inline +from .. import nn + + +@autotvm.register_topi_compute("sparse_dense.cuda") +def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr): + """ + Computes sparse-dense matrix multiplication of `data` and + `(weight_data, weight_indices, weight_indptr).T` + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data : tvm.te.Tensor + 2-D with shape [M, K], float32 + + weight_data : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 3-D with shape [num_blocks, bs_r, bs_c] (BSR) + + weight_indices : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 1-D with shape [num_blocks] (BSR) + + weight_indptr : tvm.te.Tensor + 1-D with shape [N + 1] (CSR) or + 1-D with shape [(N + 1) // bs_r] (BSR) + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [M, N] + """ + # pylint:disable=unused-argument + return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr) + + +@autotvm.register_topi_schedule("sparse_dense.cuda") +def schedule_sparse_dense(cfg, outs): + """Create schedule for sparse dense""" + # pylint:disable=invalid-name + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "sparse_dense_bsrmm": + y_bsrmm = op.input_tensors[0] + assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block" + out = s.outputs[0].output(0) + (_, c) = s[y_bsrmm].op.reduce_axis + + (m_o, n_o) = s[out].op.axis + s[out].bind(m_o, te.thread_axis("blockIdx.x")) + s[out].bind(n_o, te.thread_axis("blockIdx.y")) + s[y_bsrmm].compute_at(s[out], n_o) + + thread_x = te.thread_axis("threadIdx.x") + + cfg.define_split("tile_c", c, num_outputs=2) + if cfg.is_fallback: + cfg["tile_c"] = SplitEntity([-1, 8]) + _, ci = cfg['tile_c'].apply(s, y_bsrmm, c) + + y_bsrmm_factored = s.rfactor(y_bsrmm, ci) + tx = s[y_bsrmm].op.reduce_axis[0] + s[y_bsrmm].bind(tx, thread_x) + s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx) + s[y_bsrmm].set_store_predicate(thread_x.var.equal(0)) + s[out].set_store_predicate(thread_x.var.equal(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/nn/sparse.py b/topi/python/topi/nn/sparse.py index b37bac2..b24121b 100644 --- a/topi/python/topi/nn/sparse.py +++ b/topi/python/topi/nn/sparse.py @@ -30,7 +30,7 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr): Parameters ---------- - x : tvm.te.Tensor + data : tvm.te.Tensor 2-D with shape [M, K], float32 weight_data : tvm.te.Tensor diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py index fc2d26b..3290fc0 100644 --- a/topi/tests/python/test_topi_sparse.py +++ b/topi/tests/python/test_topi_sparse.py @@ -26,6 +26,12 @@ from collections import namedtuple import time import scipy.sparse as sp +_sparse_dense_implement = { + "generic": (topi.nn.sparse_dense, topi.generic.schedule_sparse_dense), + "cuda": (topi.cuda.sparse_dense, topi.cuda.schedule_sparse_dense), + "x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense) +} + def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") dtype = 'float32' @@ -293,16 +299,28 @@ def test_sparse_dense_bsr(): W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) - Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) - s = te.create_schedule(Y.op) - func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) - func(tvm.nd.array(X_np), - tvm.nd.array(W_sp_np.data), - tvm.nd.array(W_sp_np.indices), - tvm.nd.array(W_sp_np.indptr), - Y_tvm) - tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement) + with tvm.target.create(device): + Y = fcompute(X, W_data, W_indices, W_indptr) + s = fschedule([Y]) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + func(tvm.nd.array(X_np, ctx=ctx), + tvm.nd.array(W_sp_np.data, ctx=ctx), + tvm.nd.array(W_sp_np.indices, ctx=ctx), + tvm.nd.array(W_sp_np.indptr, ctx=ctx), + Y_tvm) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) + + for device in ['llvm', 'cuda']: + check_device(device) def test_sparse_dense_bsr_randomized(): for _ in range(20): @@ -322,16 +340,28 @@ def test_sparse_dense_bsr_randomized(): W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) - Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) - s = te.create_schedule(Y.op) - func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) - func(tvm.nd.array(X_np), - tvm.nd.array(W_sp_np.data), - tvm.nd.array(W_sp_np.indices), - tvm.nd.array(W_sp_np.indptr), - Y_tvm) - tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement) + with tvm.target.create(device): + Y = fcompute(X, W_data, W_indices, W_indptr) + s = fschedule([Y]) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + func(tvm.nd.array(X_np, ctx=ctx), + tvm.nd.array(W_sp_np.data, ctx=ctx), + tvm.nd.array(W_sp_np.indices, ctx=ctx), + tvm.nd.array(W_sp_np.indptr, ctx=ctx), + Y_tvm) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) + + for device in ['llvm', 'cuda']: + check_device(device) def test_sparse_dense():