[Contrib] cblas batch_matmul (#3210)
authorhlu1 <14827759+hlu1@users.noreply.github.com>
Tue, 21 May 2019 23:05:28 +0000 (16:05 -0700)
committerLeyuan Wang <laurawly@gmail.com>
Tue, 21 May 2019 23:05:28 +0000 (16:05 -0700)
cmake/modules/contrib/BLAS.cmake
python/tvm/contrib/cblas.py
src/contrib/cblas/cblas.cc
src/contrib/cblas/gemm_common.h
tests/python/contrib/test_cblas.py

index e1e151d..a47f837 100644 (file)
@@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl")
   if(NOT IS_DIRECTORY ${USE_MKL_PATH})
     set(USE_MKL_PATH /opt/intel/mkl)
   endif()
-  find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  if(APPLE)
+    find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  elseif(UNIX)
+    find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  endif()
   include_directories(${USE_MKL_PATH}/include)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
   list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
index c656fcc..7c024b7 100644 (file)
 """External function interface to BLAS libraries."""
 from __future__ import absolute_import as _abs
 
-from .. import api as _api
-from .. import intrin as _intrin
+from .. import api as _api, intrin as _intrin
 
-def matmul(lhs, rhs, transa=False, transb=False):
+
+def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
     """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
 
     This function serves as an example on how to call external libraries.
@@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
     n = lhs.shape[1] if transa else lhs.shape[0]
     m = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
-        (n, m), [lhs, rhs],
+        (n, m),
+        [lhs, rhs],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        name="C",
+        **kwargs
+    )
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
+    """Create an extern op that compute batched matrix mult of A and rhs with CBLAS
+     This function serves as an example on how to call external libraries.
+     Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether transpose lhs
+    transb : bool
+        Whether transpose rhs
+     Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return _api.extern(
+        (b, n, m),
+        [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
-            "tvm.contrib.cblas.matmul",
-            ins[0], ins[1], outs[0], transa, transb), name="C")
+            "tvm.contrib.cblas.batch_matmul"
+            if not iterative
+            else "tvm.contrib.cblas.batch_matmul_iterative",
+            ins[0],
+            ins[1],
+            outs[0],
+            transa,
+            transb,
+        ),
+        name="C",
+        **kwargs
+    )
index 4ca043f..0f222e2 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  *  Copyright (c) 2017 by Contributors
  * \file Use external cblas library call.
  */
+#include <dmlc/logging.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
-#include <dmlc/logging.h>
 #include "gemm_common.h"
 
-
 extern "C" {
 #if USE_MKL_BLAS == 1
 #include <mkl_cblas.h>
@@ -40,56 +39,148 @@ namespace contrib {
 
 using namespace runtime;
 
-inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
-  return trans ? CblasTrans : CblasNoTrans;
-}
+inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }
 
 struct CblasSgemmOp {
   typedef float TDatatype;
-  void operator()(bool ta, bool tb,
-                  int M, int N, int K,
-                  float alpha, float* A, int lda,
-                  float* B, int ldb,
-                  float beta, float* C, int ldc) {
-    cblas_sgemm(CblasColMajor,
-                BooleanToTranspose(ta),
-                BooleanToTranspose(tb),
-                M, N, K,
-                alpha, A, lda,
-                B, ldb,
-                beta, C, ldc);
+  void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
+                  int ldb, float beta, float* C, int ldc) {
+    cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
+                lda, B, ldb, beta, C, ldc);
   }
 };
 
 struct CblasDgemmOp {
   typedef double TDatatype;
-  void operator()(bool ta, bool tb,
-                  int M, int N, int K,
-                  double alpha, double* A, int lda,
-                  double* B, int ldb,
-                  double beta, double* C, int ldc) {
-    cblas_dgemm(CblasColMajor,
-                BooleanToTranspose(ta),
-                BooleanToTranspose(tb),
-                M, N, K,
-                alpha, A, lda,
-                B, ldb,
-                beta, C, ldc);
+  void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
+                  double* B, int ldb, double beta, double* C, int ldc) {
+    cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
+                lda, B, ldb, beta, C, ldc);
   }
 };
 
+struct CblasSgemmBatchOp {
+  typedef float TDatatype;
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                  int c_stride, int ldc) {
+    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+#if USE_MKL_BLAS == 1
+    std::vector<const float*> A_array(batch_size);
+    std::vector<const float*> B_array(batch_size);
+    std::vector<float*> C_array(batch_size);
+    for (int i = 0; i < batch_size; ++i) {
+      A_array[i] = A + i * a_stride;
+      B_array[i] = B + i * b_stride;
+      C_array[i] = C + i * c_stride;
+    }
+    cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
+                      B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
+#else
+    for (int i = 0; i < batch_size; ++i) {
+      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+      A += a_stride;
+      B += b_stride;
+      C += c_stride;
+    }
+#endif
+  }
+};
+
+struct CblasSgemmBatchIterativeOp {
+  typedef float TDatatype;
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+                  int c_stride, int ldc) {
+    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+    for (int i = 0; i < batch_size; ++i) {
+      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+      A += a_stride;
+      B += b_stride;
+      C += c_stride;
+    }
+  }
+};
+
+struct CblasDgemmBatchOp {
+  typedef double TDatatype;
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                  int c_stride, int ldc) {
+    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+#if USE_MKL_BLAS == 1
+    std::vector<const double*> A_array(batch_size);
+    std::vector<const double*> B_array(batch_size);
+    std::vector<double*> C_array(batch_size);
+    for (int i = 0; i < batch_size; ++i) {
+      A_array[i] = A + i * a_stride;
+      B_array[i] = B + i * b_stride;
+      C_array[i] = C + i * c_stride;
+    }
+    cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
+                      B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
+#else
+    for (int i = 0; i < batch_size; ++i) {
+      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+      A += a_stride;
+      B += b_stride;
+      C += c_stride;
+    }
+#endif
+  }
+};
+
+struct CblasDgemmBatchIterativeOp {
+  typedef double TDatatype;
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                  int c_stride, int ldc) {
+    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+    for (int i = 0; i < batch_size; ++i) {
+      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+      A += a_stride;
+      B += b_stride;
+      C += c_stride;
+    }
+  }
+};
 
 // matrix multiplication for row major
 TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    DLTensor* A = args[0];
-    CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
-          TypeMatch(A->dtype, kDLFloat, 64));
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+
+  if (TypeMatch(A->dtype, kDLFloat, 32))
+    CallGemm(args, ret, CblasSgemmOp());
+  else
+    CallGemm(args, ret, CblasDgemmOp());
+});
 
-    if (TypeMatch(A->dtype, kDLFloat, 32))
-      CallGemm(args, ret, CblasSgemmOp());
-    else
-      CallGemm(args, ret, CblasDgemmOp());
-  });
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+  if (TypeMatch(A->dtype, kDLFloat, 32)) {
+    CallBatchGemm(args, ret, CblasSgemmBatchOp());
+  } else {
+    CallBatchGemm(args, ret, CblasDgemmBatchOp());
+  }
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+  if (TypeMatch(A->dtype, kDLFloat, 32)) {
+    CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
+  } else {
+    CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
+  }
+});
 }  // namespace contrib
 }  // namespace tvm
index fe38b2a..2bcefb2 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  * \file tvm/contrib/gemm.h
  * \brief Shared implementation of gemm
  */
-#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
-#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
+#pragma once
+
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/util.h>
 #include <algorithm>
 
 namespace tvm {
 namespace contrib {
 
 using namespace runtime;
-
-inline int ColumnStride(DLTensor* tensor) {
+inline int ColumnStride(DLTensor *tensor) {
   // If the tensor itself is transposed then it will have strides
   // backward from what we expect.  Regardless, the max of the strides
   // (the other stride is 1) is the column stride.
@@ -42,8 +43,7 @@ inline int ColumnStride(DLTensor* tensor) {
   }
 }
 
-
-inline int ElementStride(DLTensor* tensor) {
+inline int ElementStride(DLTensor *tensor) {
   if (tensor->strides) {
     return std::min(tensor->strides[0], tensor->strides[1]);
   } else {
@@ -51,29 +51,26 @@ inline int ElementStride(DLTensor* tensor) {
   }
 }
 
-
 // Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed(DLTensortensor) {
+inline bool IsInPlaceTransposed(DLTensor *tensor) {
   return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
 }
 
-
-inline int RowCount(DLTensor* tensor, bool trans) {
+inline int RowCount(DLTensor *tensor, bool trans) {
   return tensor->shape[trans ? 1 : 0];
 }
 
-
-inline int ColumnCount(DLTensor* tensor, bool trans) {
+inline int ColumnCount(DLTensor *tensor, bool trans) {
   return tensor->shape[trans ? 0 : 1];
 }
 
 // Call a column major blas.  Note that data is stored in tvm as row
 // major, so this we switch the arguments.
-template<typename TGemmOp>
+template <typename TGemmOp>
 inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
-  DLTensorA = args[0];
-  DLTensorB = args[1];
-  DLTensorC = args[2];
+  DLTensor *A = args[0];
+  DLTensor *B = args[1];
+  DLTensor *C = args[2];
   bool transa = args[3];
   bool transb = args[4];
   int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
@@ -96,25 +93,88 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
   CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
   double alpha = args.size() > 5 ? args[5] : 1.0;
   double beta = args.size() > 6 ? args[6] : 0.0;
-  op(transb,
-     transa,
-     ColumnCount(B, transb),
-     RowCount(A, transa),
-     ColumnCount(A, transa),
-     static_cast<float>(alpha),
-     reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(B->data)
-                                                    + B->byte_offset),
+  op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
+     ColumnCount(A, transa), static_cast<float>(alpha),
+     reinterpret_cast<typename TGemmOp::TDatatype *>(
+         static_cast<char *>(B->data) + B->byte_offset),
      ColumnStride(B),
-     reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(A->data)
-                                                    + A->byte_offset),
-     ColumnStride(A),
-     static_cast<float>(beta),
-     reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(C->data)
-                                                    + C->byte_offset),
+     reinterpret_cast<typename TGemmOp::TDatatype *>(
+         static_cast<char *>(A->data) + A->byte_offset),
+     ColumnStride(A), static_cast<float>(beta),
+     reinterpret_cast<typename TGemmOp::TDatatype *>(
+         static_cast<char *>(C->data) + C->byte_offset),
      ColumnStride(C));
 }
 
+inline int ColumnStride3D(DLTensor *tensor) {
+  // If the tensor itself is transposed then it will have strides
+  // backward from what we expect.  Regardless, the max of the strides
+  // (the other stride is 1) is the column stride.
+  if (tensor->strides) {
+    return std::max(tensor->strides[1], tensor->strides[2]);
+  } else {
+    return tensor->shape[2];
+  }
+}
+inline int ElementStride3D(DLTensor *tensor) {
+  if (tensor->strides) {
+    return std::min(tensor->strides[1], tensor->strides[2]);
+  } else {
+    return 1;
+  }
+}
+// Reversed strides indicates an in-place transpose operation.
+inline bool IsInPlaceTransposed3D(DLTensor *tensor) {
+  return tensor->strides && (tensor->strides[2] > tensor->strides[1]);
+}
+inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; }
+inline int RowCount3D(DLTensor *tensor, bool trans) {
+  return tensor->shape[trans ? 2 : 1];
+}
+inline int ColumnCount3D(DLTensor *tensor, bool trans) {
+  return tensor->shape[trans ? 1 : 2];
+}
+template <typename TBatchGemmOp>
+inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
+  using DType = typename TBatchGemmOp::TDatatype;
+  DLTensor *A = args[0];
+  DLTensor *B = args[1];
+  DLTensor *C = args[2];
+  bool transa = args[3];
+  bool transb = args[4];
+  int bit_depth = sizeof(DType) * 8;
+  CHECK_EQ(A->ndim, 3);
+  CHECK_EQ(B->ndim, 3);
+  CHECK_EQ(C->ndim, 3);
+  int batch_size = BatchCount3D(A);
+  CHECK_EQ(BatchCount3D(B), batch_size);
+  CHECK_EQ(BatchCount3D(C), batch_size);
+  CHECK_EQ(ElementStride(A), 1);
+  CHECK_EQ(ElementStride(B), 1);
+  CHECK_EQ(ElementStride(C), 1);
+  // C can never be transposed.
+  CHECK(!IsInPlaceTransposed3D(C));
+  // Reversed strides indicates an in-place transpose operation.
+  transa = IsInPlaceTransposed3D(A) ? !transa : transa;
+  transb = IsInPlaceTransposed3D(B) ? !transb : transb;
+  CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
+  CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
+  double alpha = args.size() > 5 ? args[5] : 1.0;
+  double beta = args.size() > 6 ? args[6] : 0.0;
+  const int A_size = A->shape[1] * A->shape[2];
+  const int B_size = B->shape[1] * B->shape[2];
+  const int C_size = C->shape[1] * C->shape[2];
+  DType *A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
+      static_cast<char *>(A->data) + A->byte_offset);
+  DType *B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
+      static_cast<char *>(B->data) + B->byte_offset);
+  DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
+      static_cast<char *>(C->data) + C->byte_offset);
+  op(batch_size, transb, transa, ColumnCount3D(B, transb),
+     RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha),
+     B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
+     static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
+}
+
 }  // namespace contrib
 }  // namespace tvm
-
-#endif  // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
index 6705328..808c07a 100644 (file)
 # under the License.
 import tvm
 import numpy as np
+import topi.testing
 from tvm.contrib import cblas
 
-def test_matmul_add():
-    n = 1024
-    l = 128
-    m = 235
-    bias = tvm.var('bias', dtype=tvm.float32)
-    A = tvm.placeholder((n, l), name='A')
-    B = tvm.placeholder((l, m), name='B')
-    C = cblas.matmul(A, B)
+def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
+    bias = tvm.var('bias', dtype=dtype)
+    ashape = (l, n) if transa else (n, l)
+    bshape = (m, l) if transb else (l, m)
+    A = tvm.placeholder(ashape, name='A', dtype=dtype)
+    B = tvm.placeholder(bshape, name='B', dtype=dtype)
+    C = cblas.matmul(A, B, transa, transb)
     D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
     s = tvm.create_schedule(D.op)
 
+    def get_numpy(a, b, bb, transa, transb):
+        if transa:
+            a = a.transpose()
+        if transb:
+            b = b.transpose()
+        return np.dot(a, b) + bb
+
     def verify(target="llvm"):
         if not tvm.module.enabled(target):
             print("skip because %s is not enabled..." % target)
@@ -38,15 +45,69 @@ def test_matmul_add():
             return
         ctx = tvm.cpu(0)
         f = tvm.build(s, [A, B, D, bias], target)
-        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
         d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
         bb = 10.0
         f(a, b, d, bb)
         tvm.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
+            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
+    verify()
+
+def test_matmul_add():
+    verify_matmul_add(235, 128, 1024)
+    verify_matmul_add(235, 128, 1024, True, False)
+    verify_matmul_add(235, 128, 1024, False, True)
+    verify_matmul_add(235, 128, 1024, True, True)
+    verify_matmul_add(1, 16, 4)
+    verify_matmul_add(1, 16, 3, True, False)
+    verify_matmul_add(1, 16, 3, False, False)
+    verify_matmul_add(1, 16, 3, True, True)
+
+def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
+    ashape = (batch, l, n) if transa else (batch, n, l)
+    bshape = (batch, m, l) if transb else (batch, l, m)
+    A = tvm.placeholder(ashape, name='A', dtype=dtype)
+    B = tvm.placeholder(bshape, name='B', dtype=dtype)
+    C = cblas.batch_matmul(A, B, transa, transb)
+    D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
+    s = tvm.create_schedule(D.op)
+
+    def get_numpy(a, b, transa, transb):
+        if transa:
+            a = a.transpose(0, 2, 1)
+        if not transb:
+            b = b.transpose(0, 2, 1)
+        return topi.testing.batch_matmul(a, b)
+
+    def verify(target="llvm"):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.cpu(0)
+        f = tvm.build(s, [A, B, D], target)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+        d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
+        f(a, b, d)
+        tvm.testing.assert_allclose(
+            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
     verify()
 
+def test_batch_matmul():
+    verify_batch_matmul(16, 235, 128, 1024)
+    verify_batch_matmul(16, 235, 128, 1024, True, False)
+    verify_batch_matmul(16, 235, 128, 1024, False, True)
+    verify_batch_matmul(16, 235, 128, 1024, True, True)
+    verify_batch_matmul(1, 1, 16, 3)
+    verify_batch_matmul(1, 1, 16, 3, True, False)
+    verify_batch_matmul(1, 1, 16, 3, False, False)
+    verify_batch_matmul(1, 1, 16, 3, True, True)
+    verify_batch_matmul(1, 1, 16, 3, iterative=True)
 
 if __name__ == "__main__":
     test_matmul_add()
+    test_batch_matmul()