if(NOT IS_DIRECTORY ${USE_MKL_PATH})
set(USE_MKL_PATH /opt/intel/mkl)
endif()
- find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+ if(APPLE)
+ find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+ elseif(UNIX)
+ find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+ endif()
include_directories(${USE_MKL_PATH}/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs
-from .. import api as _api
-from .. import intrin as _intrin
+from .. import api as _api, intrin as _intrin
-def matmul(lhs, rhs, transa=False, transb=False):
+
+def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to call external libraries.
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
- (n, m), [lhs, rhs],
+ (n, m),
+ [lhs, rhs],
+ lambda ins, outs: _intrin.call_packed(
+ "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
+ ),
+ name="C",
+ **kwargs
+ )
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
+ """Create an extern op that compute batched matrix mult of A and rhs with CBLAS
+ This function serves as an example on how to call external libraries.
+ Parameters
+ ----------
+ lhs : Tensor
+ The left matrix operand
+ rhs : Tensor
+ The right matrix operand
+ transa : bool
+ Whether transpose lhs
+ transb : bool
+ Whether transpose rhs
+ Returns
+ -------
+ C : Tensor
+ The result tensor.
+ """
+ b = lhs.shape[0]
+ n = lhs.shape[2] if transa else lhs.shape[1]
+ m = rhs.shape[1] if transb else rhs.shape[2]
+ return _api.extern(
+ (b, n, m),
+ [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
- "tvm.contrib.cblas.matmul",
- ins[0], ins[1], outs[0], transa, transb), name="C")
+ "tvm.contrib.cblas.batch_matmul"
+ if not iterative
+ else "tvm.contrib.cblas.batch_matmul_iterative",
+ ins[0],
+ ins[1],
+ outs[0],
+ transa,
+ transb,
+ ),
+ name="C",
+ **kwargs
+ )
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* Copyright (c) 2017 by Contributors
* \file Use external cblas library call.
*/
+#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
-#include <dmlc/logging.h>
#include "gemm_common.h"
-
extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
using namespace runtime;
-inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
- return trans ? CblasTrans : CblasNoTrans;
-}
+inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }
struct CblasSgemmOp {
typedef float TDatatype;
- void operator()(bool ta, bool tb,
- int M, int N, int K,
- float alpha, float* A, int lda,
- float* B, int ldb,
- float beta, float* C, int ldc) {
- cblas_sgemm(CblasColMajor,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- alpha, A, lda,
- B, ldb,
- beta, C, ldc);
+ void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
+ int ldb, float beta, float* C, int ldc) {
+ cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
+ lda, B, ldb, beta, C, ldc);
}
};
struct CblasDgemmOp {
typedef double TDatatype;
- void operator()(bool ta, bool tb,
- int M, int N, int K,
- double alpha, double* A, int lda,
- double* B, int ldb,
- double beta, double* C, int ldc) {
- cblas_dgemm(CblasColMajor,
- BooleanToTranspose(ta),
- BooleanToTranspose(tb),
- M, N, K,
- alpha, A, lda,
- B, ldb,
- beta, C, ldc);
+ void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
+ double* B, int ldb, double beta, double* C, int ldc) {
+ cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
+ lda, B, ldb, beta, C, ldc);
}
};
+struct CblasSgemmBatchOp {
+ typedef float TDatatype;
+ void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+ int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+ int c_stride, int ldc) {
+ CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+ CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+#if USE_MKL_BLAS == 1
+ std::vector<const float*> A_array(batch_size);
+ std::vector<const float*> B_array(batch_size);
+ std::vector<float*> C_array(batch_size);
+ for (int i = 0; i < batch_size; ++i) {
+ A_array[i] = A + i * a_stride;
+ B_array[i] = B + i * b_stride;
+ C_array[i] = C + i * c_stride;
+ }
+ cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
+ B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
+#else
+ for (int i = 0; i < batch_size; ++i) {
+ cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ A += a_stride;
+ B += b_stride;
+ C += c_stride;
+ }
+#endif
+ }
+};
+
+struct CblasSgemmBatchIterativeOp {
+ typedef float TDatatype;
+ void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
+ int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
+ int c_stride, int ldc) {
+ CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+ CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+ for (int i = 0; i < batch_size; ++i) {
+ cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ A += a_stride;
+ B += b_stride;
+ C += c_stride;
+ }
+ }
+};
+
+struct CblasDgemmBatchOp {
+ typedef double TDatatype;
+ void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+ int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+ int c_stride, int ldc) {
+ CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+ CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+#if USE_MKL_BLAS == 1
+ std::vector<const double*> A_array(batch_size);
+ std::vector<const double*> B_array(batch_size);
+ std::vector<double*> C_array(batch_size);
+ for (int i = 0; i < batch_size; ++i) {
+ A_array[i] = A + i * a_stride;
+ B_array[i] = B + i * b_stride;
+ C_array[i] = C + i * c_stride;
+ }
+ cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
+ B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
+#else
+ for (int i = 0; i < batch_size; ++i) {
+ cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ A += a_stride;
+ B += b_stride;
+ C += c_stride;
+ }
+#endif
+ }
+};
+
+struct CblasDgemmBatchIterativeOp {
+ typedef double TDatatype;
+ void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+ int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+ int c_stride, int ldc) {
+ CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+ CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+ for (int i = 0; i < batch_size; ++i) {
+ cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+ A += a_stride;
+ B += b_stride;
+ C += c_stride;
+ }
+ }
+};
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DLTensor* A = args[0];
- CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
- TypeMatch(A->dtype, kDLFloat, 64));
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+
+ if (TypeMatch(A->dtype, kDLFloat, 32))
+ CallGemm(args, ret, CblasSgemmOp());
+ else
+ CallGemm(args, ret, CblasDgemmOp());
+});
- if (TypeMatch(A->dtype, kDLFloat, 32))
- CallGemm(args, ret, CblasSgemmOp());
- else
- CallGemm(args, ret, CblasDgemmOp());
- });
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+ if (TypeMatch(A->dtype, kDLFloat, 32)) {
+ CallBatchGemm(args, ret, CblasSgemmBatchOp());
+ } else {
+ CallBatchGemm(args, ret, CblasDgemmBatchOp());
+ }
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ DLTensor* A = args[0];
+ CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+ if (TypeMatch(A->dtype, kDLFloat, 32)) {
+ CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
+ } else {
+ CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
+ }
+});
} // namespace contrib
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \file tvm/contrib/gemm.h
* \brief Shared implementation of gemm
*/
-#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
-#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
+#pragma once
+
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/util.h>
#include <algorithm>
namespace tvm {
namespace contrib {
using namespace runtime;
-
-inline int ColumnStride(DLTensor* tensor) {
+inline int ColumnStride(DLTensor *tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
}
}
-
-inline int ElementStride(DLTensor* tensor) {
+inline int ElementStride(DLTensor *tensor) {
if (tensor->strides) {
return std::min(tensor->strides[0], tensor->strides[1]);
} else {
}
}
-
// Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed(DLTensor* tensor) {
+inline bool IsInPlaceTransposed(DLTensor *tensor) {
return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}
-
-inline int RowCount(DLTensor* tensor, bool trans) {
+inline int RowCount(DLTensor *tensor, bool trans) {
return tensor->shape[trans ? 1 : 0];
}
-
-inline int ColumnCount(DLTensor* tensor, bool trans) {
+inline int ColumnCount(DLTensor *tensor, bool trans) {
return tensor->shape[trans ? 0 : 1];
}
// Call a column major blas. Note that data is stored in tvm as row
// major, so this we switch the arguments.
-template<typename TGemmOp>
+template <typename TGemmOp>
inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
- DLTensor* A = args[0];
- DLTensor* B = args[1];
- DLTensor* C = args[2];
+ DLTensor *A = args[0];
+ DLTensor *B = args[1];
+ DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
- op(transb,
- transa,
- ColumnCount(B, transb),
- RowCount(A, transa),
- ColumnCount(A, transa),
- static_cast<float>(alpha),
- reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(B->data)
- + B->byte_offset),
+ op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
+ ColumnCount(A, transa), static_cast<float>(alpha),
+ reinterpret_cast<typename TGemmOp::TDatatype *>(
+ static_cast<char *>(B->data) + B->byte_offset),
ColumnStride(B),
- reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(A->data)
- + A->byte_offset),
- ColumnStride(A),
- static_cast<float>(beta),
- reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(C->data)
- + C->byte_offset),
+ reinterpret_cast<typename TGemmOp::TDatatype *>(
+ static_cast<char *>(A->data) + A->byte_offset),
+ ColumnStride(A), static_cast<float>(beta),
+ reinterpret_cast<typename TGemmOp::TDatatype *>(
+ static_cast<char *>(C->data) + C->byte_offset),
ColumnStride(C));
}
+inline int ColumnStride3D(DLTensor *tensor) {
+ // If the tensor itself is transposed then it will have strides
+ // backward from what we expect. Regardless, the max of the strides
+ // (the other stride is 1) is the column stride.
+ if (tensor->strides) {
+ return std::max(tensor->strides[1], tensor->strides[2]);
+ } else {
+ return tensor->shape[2];
+ }
+}
+inline int ElementStride3D(DLTensor *tensor) {
+ if (tensor->strides) {
+ return std::min(tensor->strides[1], tensor->strides[2]);
+ } else {
+ return 1;
+ }
+}
+// Reversed strides indicates an in-place transpose operation.
+inline bool IsInPlaceTransposed3D(DLTensor *tensor) {
+ return tensor->strides && (tensor->strides[2] > tensor->strides[1]);
+}
+inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; }
+inline int RowCount3D(DLTensor *tensor, bool trans) {
+ return tensor->shape[trans ? 2 : 1];
+}
+inline int ColumnCount3D(DLTensor *tensor, bool trans) {
+ return tensor->shape[trans ? 1 : 2];
+}
+template <typename TBatchGemmOp>
+inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
+ using DType = typename TBatchGemmOp::TDatatype;
+ DLTensor *A = args[0];
+ DLTensor *B = args[1];
+ DLTensor *C = args[2];
+ bool transa = args[3];
+ bool transb = args[4];
+ int bit_depth = sizeof(DType) * 8;
+ CHECK_EQ(A->ndim, 3);
+ CHECK_EQ(B->ndim, 3);
+ CHECK_EQ(C->ndim, 3);
+ int batch_size = BatchCount3D(A);
+ CHECK_EQ(BatchCount3D(B), batch_size);
+ CHECK_EQ(BatchCount3D(C), batch_size);
+ CHECK_EQ(ElementStride(A), 1);
+ CHECK_EQ(ElementStride(B), 1);
+ CHECK_EQ(ElementStride(C), 1);
+ // C can never be transposed.
+ CHECK(!IsInPlaceTransposed3D(C));
+ // Reversed strides indicates an in-place transpose operation.
+ transa = IsInPlaceTransposed3D(A) ? !transa : transa;
+ transb = IsInPlaceTransposed3D(B) ? !transb : transb;
+ CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
+ CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
+ double alpha = args.size() > 5 ? args[5] : 1.0;
+ double beta = args.size() > 6 ? args[6] : 0.0;
+ const int A_size = A->shape[1] * A->shape[2];
+ const int B_size = B->shape[1] * B->shape[2];
+ const int C_size = C->shape[1] * C->shape[2];
+ DType *A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
+ static_cast<char *>(A->data) + A->byte_offset);
+ DType *B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
+ static_cast<char *>(B->data) + B->byte_offset);
+ DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
+ static_cast<char *>(C->data) + C->byte_offset);
+ op(batch_size, transb, transa, ColumnCount3D(B, transb),
+ RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha),
+ B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
+ static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
+}
+
} // namespace contrib
} // namespace tvm
-
-#endif // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
# under the License.
import tvm
import numpy as np
+import topi.testing
from tvm.contrib import cblas
-def test_matmul_add():
- n = 1024
- l = 128
- m = 235
- bias = tvm.var('bias', dtype=tvm.float32)
- A = tvm.placeholder((n, l), name='A')
- B = tvm.placeholder((l, m), name='B')
- C = cblas.matmul(A, B)
+def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
+ bias = tvm.var('bias', dtype=dtype)
+ ashape = (l, n) if transa else (n, l)
+ bshape = (m, l) if transb else (l, m)
+ A = tvm.placeholder(ashape, name='A', dtype=dtype)
+ B = tvm.placeholder(bshape, name='B', dtype=dtype)
+ C = cblas.matmul(A, B, transa, transb)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
+ def get_numpy(a, b, bb, transa, transb):
+ if transa:
+ a = a.transpose()
+ if transb:
+ b = b.transpose()
+ return np.dot(a, b) + bb
+
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
- a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
- b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
+ a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
tvm.testing.assert_allclose(
- d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
+ d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
+ verify()
+
+def test_matmul_add():
+ verify_matmul_add(235, 128, 1024)
+ verify_matmul_add(235, 128, 1024, True, False)
+ verify_matmul_add(235, 128, 1024, False, True)
+ verify_matmul_add(235, 128, 1024, True, True)
+ verify_matmul_add(1, 16, 4)
+ verify_matmul_add(1, 16, 3, True, False)
+ verify_matmul_add(1, 16, 3, False, False)
+ verify_matmul_add(1, 16, 3, True, True)
+
+def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
+ ashape = (batch, l, n) if transa else (batch, n, l)
+ bshape = (batch, m, l) if transb else (batch, l, m)
+ A = tvm.placeholder(ashape, name='A', dtype=dtype)
+ B = tvm.placeholder(bshape, name='B', dtype=dtype)
+ C = cblas.batch_matmul(A, B, transa, transb)
+ D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
+ s = tvm.create_schedule(D.op)
+
+ def get_numpy(a, b, transa, transb):
+ if transa:
+ a = a.transpose(0, 2, 1)
+ if not transb:
+ b = b.transpose(0, 2, 1)
+ return topi.testing.batch_matmul(a, b)
+
+ def verify(target="llvm"):
+ if not tvm.module.enabled(target):
+ print("skip because %s is not enabled..." % target)
+ return
+ if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
+ print("skip because extern function is not available")
+ return
+ ctx = tvm.cpu(0)
+ f = tvm.build(s, [A, B, D], target)
+ a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+ d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
+ f(a, b, d)
+ tvm.testing.assert_allclose(
+ d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
verify()
+def test_batch_matmul():
+ verify_batch_matmul(16, 235, 128, 1024)
+ verify_batch_matmul(16, 235, 128, 1024, True, False)
+ verify_batch_matmul(16, 235, 128, 1024, False, True)
+ verify_batch_matmul(16, 235, 128, 1024, True, True)
+ verify_batch_matmul(1, 1, 16, 3)
+ verify_batch_matmul(1, 1, 16, 3, True, False)
+ verify_batch_matmul(1, 1, 16, 3, False, False)
+ verify_batch_matmul(1, 1, 16, 3, True, True)
+ verify_batch_matmul(1, 1, 16, 3, iterative=True)
if __name__ == "__main__":
test_matmul_add()
+ test_batch_matmul()