Add rocblas_sgemm_strided_batched impl. (#6579)
authorChris Sullivan <csullivan@octoml.ai>
Tue, 29 Sep 2020 09:59:54 +0000 (02:59 -0700)
committerGitHub <noreply@github.com>
Tue, 29 Sep 2020 09:59:54 +0000 (18:59 +0900)
include/tvm/topi/contrib/rocblas.h
python/tvm/contrib/rocblas.py
python/tvm/relay/op/strategy/rocm.py
python/tvm/topi/rocm/__init__.py
python/tvm/topi/rocm/batch_matmul.py [new file with mode: 0644]
src/runtime/contrib/rocblas/rocblas.cc
tests/python/contrib/test_rocblas.py

index a4fa26f..4f0b887 100644 (file)
@@ -54,6 +54,29 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa,
       },
       "C", "", {})[0];
 }
+/*!
+ * \brief Create an op that batch multiplies lhs and rhs with rocBLAS
+ *
+ * \param lhs The left matrix operand  e.g. (batch_size, M, K)
+ * \param rhs The right matrix operand e.g. (batch_size, K, N)
+ * \param transa Whether to transpose lhs
+ * \param transb Whether to transpose rhs
+ *
+ * \return The output tensor
+ */
+inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
+  auto batch_size = lhs->shape[0];
+  auto n = transa ? lhs->shape[2] : lhs->shape[1];
+  auto m = transb ? rhs->shape[1] : rhs->shape[2];
+
+  return make_extern(
+      {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs},
+      [&](Array<Buffer> ins, Array<Buffer> outs) {
+        return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
+                            pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
+      },
+      "C", "", {})[0];
+}
 
 }  // namespace contrib
 }  // namespace topi
index 03ea2b5..70791dc 100644 (file)
@@ -48,3 +48,36 @@ def matmul(lhs, rhs, transa=False, transb=False):
         ),
         name="C",
     )
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False):
+    """Create an extern op that compute matrix mult of A and rhs with rocBLAS
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left batched matrix operand
+    rhs : Tensor
+        The right batched matrix operand
+    transa : bool
+        Whether transpose lhs
+    transb : bool
+        Whether transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    batch_size = lhs.shape[0]
+    assert batch_size == rhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return te.extern(
+        (batch_size, n, m),
+        [lhs, rhs],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.rocblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        name="C",
+    )
index 2410260..f52bbc3 100644 (file)
@@ -160,3 +160,24 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
             plevel=15,
         )
     return strategy
+
+
+@batch_matmul_strategy.register("rocm")
+def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
+    """Batch matmul strategy for ROCM"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
+        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
+        name="batch_matmul.cuda",
+        plevel=10,
+    )
+    if target.kind.name == "rocm" and "rocblas" in target.libs:
+        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
+        strategy.add_implementation(
+            wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas),
+            wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas),
+            name="batch_matmul_rocblas.rocm",
+            plevel=12,
+        )
+    return strategy
index 4efdab4..1ea4c79 100644 (file)
@@ -19,6 +19,7 @@
 """rocm specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
+from .batch_matmul import *
 from .conv2d import *
 from .dense import *
 from .nn import *
diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py
new file mode 100644 (file)
index 0000000..fa4dd45
--- /dev/null
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, unused-argument
+"""Schedule for batch_matmul operator"""
+from tvm import autotvm
+from tvm.contrib import rocblas
+from .. import generic
+from ..util import get_const_tuple
+
+
+@autotvm.register_topi_compute("batch_matmul_rocblas.rocm")
+def batch_matmul_rocblas(cfg, x, y, out_shape=None):
+    """Computes matrix multiplication of `x` and `y` via rocblas when
+    `x` and `y` are batched matrices.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K]
+    y : tvm.te.Tensor
+        3-D with shape [batch, N, K]
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    batch, M, K = get_const_tuple(x.shape)
+    _, N, _ = get_const_tuple(y.shape)
+    if out_shape is not None:
+        assert out_shape[0] == batch, "Input and output batch sizes must match"
+        assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
+    result = rocblas.batch_matmul(x, y, False, True)
+    cfg.add_flop(batch * M * N * K * 2)
+    return result
+
+
+@autotvm.register_topi_schedule("batch_matmul_rocblas.rocm")
+def schedule_batch_matmul_rocblas(_, outs):
+    """Schedule for batch_matmul operator with rocm cblas"""
+    return generic.schedule_extern(outs)
index 0e6f4bd..bca00a5 100644 (file)
@@ -70,18 +70,61 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMR
   CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
   float alpha = 1.0;
   float beta = 0.0;
-  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
-  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
   float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
 
-  CHECK_ROCBLAS_ERROR(
-      rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none,
-                    transa ? rocblas_operation_transpose : rocblas_operation_none,
-                    transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0],
-                    transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr,
-                    A->shape[1], &beta, C_ptr, C->shape[1]));
+  rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+  rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+  size_t N = transb ? B->shape[0] : B->shape[1];
+  size_t M = transa ? A->shape[1] : A->shape[0];
+  size_t K = transb ? B->shape[1] : B->shape[0];
+  size_t lda = transa ? M : K;
+  size_t ldb = transb ? K : N;
+  size_t ldc = N;
+
+  CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb,
+                                    A_ptr, lda, &beta, C_ptr, ldc));
 
   CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
 });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      DLTensor* A = args[0];
+      DLTensor* B = args[1];
+      DLTensor* C = args[2];
+      bool transa = args[3];
+      bool transb = args[4];
+      // call gemm for simple compact code.
+      CHECK_EQ(A->ndim, 3);
+      CHECK_EQ(B->ndim, 3);
+      CHECK_EQ(C->ndim, 3);
+      CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+
+      rocblas_handle handle;
+      CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
+      float alpha = 1.0;
+      float beta = 0.0;
+      float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+      float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
+      float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
+
+      rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+      rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+      size_t batch_size = C->shape[0];
+      size_t N = transb ? B->shape[1] : B->shape[2];
+      size_t M = transa ? A->shape[2] : A->shape[1];
+      size_t K = transb ? B->shape[2] : B->shape[1];
+      size_t lda = transa ? M : K;
+      size_t ldb = transb ? K : N;
+      size_t ldc = N;
+
+      CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched(
+          handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, K * N, A_ptr, lda, M * K,
+          &beta, C_ptr, ldc, M * N, batch_size));
+    });
 }  // namespace contrib
 }  // namespace tvm
index 9b8bacb..6f1783d 100644 (file)
@@ -18,11 +18,13 @@ import tvm
 import tvm.testing
 from tvm import te
 import numpy as np
+import tvm.topi.testing
+import tvm.testing
 from tvm.contrib import rocblas
 
 
 @tvm.testing.requires_rocm
-def test_matmul_add():
+def test_matmul():
     n = 1024
     l = 128
     m = 235
@@ -46,5 +48,65 @@ def test_matmul_add():
     verify()
 
 
+def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False, dtype="float32"):
+    ashape = (batch, k, m) if transa else (batch, m, k)
+    bshape = (batch, n, k) if transb else (batch, k, n)
+    A = te.placeholder(ashape, name="A", dtype=dtype)
+    B = te.placeholder(bshape, name="B", dtype=dtype)
+    C = lib.batch_matmul(A, B, transa, transb)
+    s = te.create_schedule(C.op)
+
+    def get_numpy(a, b, transa, transb):
+        if transa:
+            a = a.transpose(0, 2, 1)
+        if not transb:
+            b = b.transpose(0, 2, 1)
+        return tvm.topi.testing.batch_matmul(a, b)
+
+    def verify(target="rocm"):
+        if not tvm.testing.device_enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func(lib.__name__ + ".batch_matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.rocm(0)
+        f = tvm.build(s, [A, B, C], target)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), ctx)
+        f(a, b, c)
+        tvm.testing.assert_allclose(
+            c.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5
+        )
+
+    verify()
+
+
+@tvm.testing.requires_rocm
+def test_batch_matmul():
+    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False)
+    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True)
+    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False)
+    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=True)
+    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=False)
+    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=True)
+    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=False)
+    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=True)
+    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=False)
+    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=True)
+    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=False)
+    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=True)
+    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=False)
+    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=True)
+    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=False)
+    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=True)
+    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=False)
+    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=True)
+    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=False)
+    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=True)
+
+
 if __name__ == "__main__":
-    test_matmul_add()
+    test_matmul()
+    test_batch_matmul()