},
"C", "", {})[0];
}
+/*!
+ * \brief Create an op that batch multiplies lhs and rhs with rocBLAS
+ *
+ * \param lhs The left matrix operand e.g. (batch_size, M, K)
+ * \param rhs The right matrix operand e.g. (batch_size, K, N)
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
+ auto batch_size = lhs->shape[0];
+ auto n = transa ? lhs->shape[2] : lhs->shape[1];
+ auto m = transb ? rhs->shape[1] : rhs->shape[2];
+
+ return make_extern(
+ {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs},
+ [&](Array<Buffer> ins, Array<Buffer> outs) {
+ return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
+ pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
+ },
+ "C", "", {})[0];
+}
} // namespace contrib
} // namespace topi
),
name="C",
)
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+ """Create an extern op that compute matrix mult of A and rhs with rocBLAS
+
+ Parameters
+ ----------
+ lhs : Tensor
+ The left batched matrix operand
+ rhs : Tensor
+ The right batched matrix operand
+ transa : bool
+ Whether transpose lhs
+ transb : bool
+ Whether transpose rhs
+
+ Returns
+ -------
+ C : Tensor
+ The result tensor.
+ """
+ batch_size = lhs.shape[0]
+ assert batch_size == rhs.shape[0]
+ n = lhs.shape[2] if transa else lhs.shape[1]
+ m = rhs.shape[1] if transb else rhs.shape[2]
+ return te.extern(
+ (batch_size, n, m),
+ [lhs, rhs],
+ lambda ins, outs: tvm.tir.call_packed(
+ "tvm.contrib.rocblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
+ ),
+ name="C",
+ )
plevel=15,
)
return strategy
+
+
+@batch_matmul_strategy.register("rocm")
+def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
+ """Batch matmul strategy for ROCM"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_batch_matmul(topi.cuda.batch_matmul),
+ wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
+ name="batch_matmul.cuda",
+ plevel=10,
+ )
+ if target.kind.name == "rocm" and "rocblas" in target.libs:
+ assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
+ strategy.add_implementation(
+ wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas),
+ wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas),
+ name="batch_matmul_rocblas.rocm",
+ plevel=12,
+ )
+ return strategy
"""rocm specific declaration and schedules."""
from __future__ import absolute_import as _abs
+from .batch_matmul import *
from .conv2d import *
from .dense import *
from .nn import *
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, unused-argument
+"""Schedule for batch_matmul operator"""
+from tvm import autotvm
+from tvm.contrib import rocblas
+from .. import generic
+from ..util import get_const_tuple
+
+
+@autotvm.register_topi_compute("batch_matmul_rocblas.rocm")
+def batch_matmul_rocblas(cfg, x, y, out_shape=None):
+ """Computes matrix multiplication of `x` and `y` via rocblas when
+ `x` and `y` are batched matrices.
+
+ Parameters
+ ----------
+ cfg : ConfigSpace
+ Autotvm tuning space config file
+ x : tvm.te.Tensor
+ 3-D with shape [batch, M, K]
+ y : tvm.te.Tensor
+ 3-D with shape [batch, N, K]
+ Returns
+ -------
+ output : tvm.te.Tensor
+ 3-D with shape [batch, M, N]
+ """
+ batch, M, K = get_const_tuple(x.shape)
+ _, N, _ = get_const_tuple(y.shape)
+ if out_shape is not None:
+ assert out_shape[0] == batch, "Input and output batch sizes must match"
+ assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
+ result = rocblas.batch_matmul(x, y, False, True)
+ cfg.add_flop(batch * M * N * K * 2)
+ return result
+
+
+@autotvm.register_topi_schedule("batch_matmul_rocblas.rocm")
+def schedule_batch_matmul_rocblas(_, outs):
+ """Schedule for batch_matmul operator with rocm cblas"""
+ return generic.schedule_extern(outs)
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
float alpha = 1.0;
float beta = 0.0;
- float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
- float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+ float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+ float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
- CHECK_ROCBLAS_ERROR(
- rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none,
- transa ? rocblas_operation_transpose : rocblas_operation_none,
- transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0],
- transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr,
- A->shape[1], &beta, C_ptr, C->shape[1]));
+ rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+ rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+ size_t N = transb ? B->shape[0] : B->shape[1];
+ size_t M = transa ? A->shape[1] : A->shape[0];
+ size_t K = transb ? B->shape[1] : B->shape[0];
+ size_t lda = transa ? M : K;
+ size_t ldb = transb ? K : N;
+ size_t ldc = N;
+
+ CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb,
+ A_ptr, lda, &beta, C_ptr, ldc));
CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul")
+ .set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[2];
+ bool transa = args[3];
+ bool transb = args[4];
+ // call gemm for simple compact code.
+ CHECK_EQ(A->ndim, 3);
+ CHECK_EQ(B->ndim, 3);
+ CHECK_EQ(C->ndim, 3);
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+ CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+ CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+
+ rocblas_handle handle;
+ CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
+ float alpha = 1.0;
+ float beta = 0.0;
+ float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+ float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
+ float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
+
+ rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+ rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+ size_t batch_size = C->shape[0];
+ size_t N = transb ? B->shape[1] : B->shape[2];
+ size_t M = transa ? A->shape[2] : A->shape[1];
+ size_t K = transb ? B->shape[2] : B->shape[1];
+ size_t lda = transa ? M : K;
+ size_t ldb = transb ? K : N;
+ size_t ldc = N;
+
+ CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched(
+ handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, K * N, A_ptr, lda, M * K,
+ &beta, C_ptr, ldc, M * N, batch_size));
+ });
} // namespace contrib
} // namespace tvm
import tvm.testing
from tvm import te
import numpy as np
+import tvm.topi.testing
+import tvm.testing
from tvm.contrib import rocblas
@tvm.testing.requires_rocm
-def test_matmul_add():
+def test_matmul():
n = 1024
l = 128
m = 235
verify()
+def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False, dtype="float32"):
+ ashape = (batch, k, m) if transa else (batch, m, k)
+ bshape = (batch, n, k) if transb else (batch, k, n)
+ A = te.placeholder(ashape, name="A", dtype=dtype)
+ B = te.placeholder(bshape, name="B", dtype=dtype)
+ C = lib.batch_matmul(A, B, transa, transb)
+ s = te.create_schedule(C.op)
+
+ def get_numpy(a, b, transa, transb):
+ if transa:
+ a = a.transpose(0, 2, 1)
+ if not transb:
+ b = b.transpose(0, 2, 1)
+ return tvm.topi.testing.batch_matmul(a, b)
+
+ def verify(target="rocm"):
+ if not tvm.testing.device_enabled(target):
+ print("skip because %s is not enabled..." % target)
+ return
+ if not tvm.get_global_func(lib.__name__ + ".batch_matmul", True):
+ print("skip because extern function is not available")
+ return
+ ctx = tvm.rocm(0)
+ f = tvm.build(s, [A, B, C], target)
+ a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+ c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), ctx)
+ f(a, b, c)
+ tvm.testing.assert_allclose(
+ c.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5
+ )
+
+ verify()
+
+
+@tvm.testing.requires_rocm
+def test_batch_matmul():
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=True)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=False)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=True)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=False)
+ verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=True)
+
+
if __name__ == "__main__":
- test_matmul_add()
+ test_matmul()
+ test_batch_matmul()