From 8cd53e00722d079290b53ade348b860f9c237ee9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 28 Jul 2020 13:32:00 -0700 Subject: [PATCH] [Topi, x86] Using MKL blas for quantized dense (#6115) * [Topi, x86] Using MKL blas for quantized dense * Typo * CBLAS_OFFSET only available for MKL * Skipping tests as GPU CI uses Openblas * Retrigger Co-authored-by: Ubuntu --- python/tvm/contrib/cblas.py | 33 +++++++++++++++++++++ src/runtime/contrib/cblas/cblas.cc | 41 ++++++++++++++++++++++++++ src/runtime/contrib/cblas/gemm_common.h | 48 ++++++++++++++++++++++++++++++ tests/python/contrib/test_cblas.py | 52 +++++++++++++++++++++++++++++++++ topi/python/topi/x86/dense.py | 8 ++++- 5 files changed, 181 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index e1a4a8a..68586df 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -52,6 +52,39 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): ) +def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs): + """Create an extern op that compute matrix mult of A and rhs with CrhsLAS + This function serves as an example on how to call external libraries. + + Parameters + ---------- + lhs: Tensor + The left matrix operand + rhs: Tensor + The right matrix operand + transa: bool + Whether transpose lhs + transb: bool + Whether transpose rhs + + Returns + ------- + C: Tensor + The result tensor. + """ + n = lhs.shape[1] if transa else lhs.shape[0] + m = rhs.shape[0] if transb else rhs.shape[1] + return te.extern( + (n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cblas.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) + + def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs): """Create an extern op that compute batched matrix mult of A and rhs with CBLAS This function serves as an example on how to call external libraries. diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 0cf4c69..e84ee11 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -44,8 +44,37 @@ using namespace runtime; inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } +#if USE_MKL_BLAS == 1 +inline CBLAS_OFFSET StringToOffset(const std::string offset_type) { + if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" && + offset_type != "CblasRowOffset") { + LOG(FATAL) << "Unrecognized offset_type " << offset_type; + } + if (offset_type == "CblasFixOffset") { + return CblasFixOffset; + } else if (offset_type == "CblasColOffset") { + return CblasColOffset; + } + return CblasRowOffset; +} +#endif + inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; } +struct CblasGemmU8S8S32Op { + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, const void* A, int lda, + int offset_a, const void* B, int ldb, int offset_b, float beta, int* C, int ldc, + const std::string offset_ctype, int* offset_c) { +#if USE_MKL_BLAS == 1 + cblas_gemm_s8u8s32(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), + StringToOffset(offset_ctype), M, N, K, alpha, A, lda, offset_a, B, ldb, + offset_b, beta, C, ldc, offset_c); +#else + LOG(FATAL) << "Quantized Gemm is supported with MKL Blas only"; +#endif + } +}; + struct CblasSgemmOp { typedef float TDatatype; void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, @@ -170,6 +199,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRet CallGemm(args, ret, CblasDgemmOp()); }); +// integer matrix multiplication for row major +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul_u8s8s32") + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + CHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) && + TypeMatch(C->dtype, kDLInt, 32)); + + CallU8S8S32Gemm(args, ret, CblasGemmU8S8S32Op()); + }); + TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 96d6322..d92f9d7 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -27,6 +27,7 @@ #include #include +#include namespace tvm { namespace contrib { @@ -99,6 +100,53 @@ inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { ColumnStride(C)); } +// Call a column major blas. Note that data is stored in tvm as row +// major, so this we switch the arguments. +template +inline void CallU8S8S32Gemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + + // Set the sgemm attributes. Currently, support is limited to CblasFixOffset with all offsets + // equal to 0. This is sufficient for relay dense. + std::string offset_ctype = "CblasFixOffset"; + int16_t offset_a = 0; + int16_t offset_b = 0; + int offset_c[1]; + offset_c[0] = 0; + + CHECK_EQ(A->ndim, 2); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 2); + + CHECK_EQ(ElementStride(A), 1); + CHECK_EQ(ElementStride(B), 1); + CHECK_EQ(ElementStride(C), 1); + + // C can never be transposed. + CHECK(!IsInPlaceTransposed(C)); + + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed(A) ? !transa : transa; + transb = IsInPlaceTransposed(B) ? !transb : transb; + + CHECK(TypeMatch(A->dtype, kDLUInt, 8)); + CHECK(TypeMatch(B->dtype, kDLInt, 8)); + CHECK(TypeMatch(C->dtype, kDLInt, 32)); + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? args[6] : 0.0; + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), + static_cast(alpha), + reinterpret_cast(static_cast(B->data) + B->byte_offset), ColumnStride(B), + offset_b, reinterpret_cast(static_cast(A->data) + A->byte_offset), + ColumnStride(A), offset_a, static_cast(beta), + reinterpret_cast(static_cast(C->data) + C->byte_offset), ColumnStride(C), + offset_ctype, offset_c); +} + inline int ColumnStride3D(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index 18ea57a..54f4ff6 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import te import numpy as np @@ -65,6 +66,57 @@ def test_matmul_add(): verify_matmul_add(1, 16, 3, False, False) verify_matmul_add(1, 16, 3, True, True) +def verify_quantized_matmul_add(m, l, n, transa=False, transb=False): + pytest.skip("Quantized dense is supported only for MKL. TVM GPU CI uses openblas") + data_dtype = "uint8" + kernel_dtype = "int8" + out_dtype = "int32" + bias = te.var('bias', dtype=out_dtype) + ashape = (l, n) if transa else (n, l) + bshape = (m, l) if transb else (l, m) + A = te.placeholder(ashape, name='A', dtype=data_dtype) + B = te.placeholder(bshape, name='B', dtype=kernel_dtype) + C = cblas.matmul_u8s8s32(A, B, transa, transb, dtype=out_dtype) + D = te.compute(C.shape, lambda i, j: C[i,j] + bias, name="D") + s = te.create_schedule(D.op) + + def get_numpy(a, b, bb, transa, transb): + if transa: + a = a.transpose() + if transb: + b = b.transpose() + return np.dot(a, b) + bb + + def verify(target="llvm"): + if not tvm.runtime.enabled(target): + print("skip because %s is not enabled..." % target) + return + if not tvm.get_global_func("tvm.contrib.cblas.matmul_u8s8s32", True): + print("skip because extern function is not available") + return + ctx = tvm.cpu(0) + f = tvm.build(s, [A, B, D, bias], target) + a = tvm.nd.array(np.random.randint(low=0, high=50, size=ashape).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.randint(low=0, high=50, size=bshape).astype(B.dtype), ctx) + d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx) + bb = 10 + f(a, b, d, bb) + tvm.testing.assert_allclose( + d.asnumpy(), + get_numpy(a.asnumpy().astype('int32'), b.asnumpy().astype('int32'), bb, transa, transb), + rtol=1e-5) + verify() + +def test_quantized_matmul_add(): + verify_quantized_matmul_add(235, 128, 1024) + verify_quantized_matmul_add(235, 128, 1024, True, False) + verify_quantized_matmul_add(235, 128, 1024, False, True) + verify_quantized_matmul_add(235, 128, 1024, True, True) + verify_quantized_matmul_add(1, 16, 4) + verify_quantized_matmul_add(1, 16, 3, True, False) + verify_quantized_matmul_add(1, 16, 3, False, True) + verify_quantized_matmul_add(1, 16, 3, True, True) + def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype="float32"): ashape = (batch, l, n) if transa else (batch, n, l) bshape = (batch, m, l) if transb else (batch, l, m) diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py index 3e99d06..500fcfc 100644 --- a/topi/python/topi/x86/dense.py +++ b/topi/python/topi/x86/dense.py @@ -226,7 +226,13 @@ def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) cfg.add_flop(M * K * N * 2) - C = cblas.matmul(data, weight, False, True) + if data.dtype == 'uint8' and weight.dtype == 'int8' and out_dtype == 'int32': + C = cblas.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype) + elif data.dtype == 'float32': + C = cblas.matmul(data, weight, False, True) + else: + raise NotImplementedError(f"Dense with cblas for {data.dtype} is not supported") + if bias is not None: C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) -- 2.7.4