static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
const int batch_size, const int M, const int N, const int K, const float alpha,
- const float** A, const float** B, const float beta, float** C) {
- const int lda = (trans_A == CblasNoTrans) ? K : M;
- const int ldb = (trans_B == CblasNoTrans) ? N : K;
- const int ldc = N;
+ const float** A, const int lda, const float** B, const int ldb, const float beta,
+ float** C, const int ldc) {
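+ // group_count = 1 with group_size = batch_size: a single MKL "group", so one
+ // set of transpose flags, sizes, and leading dimensions covers every matrix
+ // in the batch.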
cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
    A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
}

static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
const int batch_size, const int M, const int N, const int K, const double alpha,
- const double** A, const double** B, const double beta, double** C) {
- const int lda = (trans_A == CblasNoTrans) ? K : M;
- const int ldb = (trans_B == CblasNoTrans) ? N : K;
- const int ldc = N;
+ const double** A, const int lda, const double** B, const int ldb, const double beta,
+ double** C, const int ldc) {
cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
    A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
}

template <typename scalar_t>
static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
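+ // A matrix counts as transposed when it is column-major. Its leading
+ // dimension, stride(1), can exceed size(0) for padded or sliced views, so
+ // the check below uses >= rather than ==.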
auto is_transposed = [&](const Tensor& t) {
- return t.stride(0) == 1 && t.stride(1) == t.size(0);
+ return t.stride(0) == 1 && t.stride(1) >= t.size(0);
};
const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;
const int batch_size = mat1.size(0);
const int M = mat1.size(1);
const int N = mat2.size(2);
const int K = mat1.size(2);
scalar_t alpha = alpha_.to<scalar_t>();
scalar_t beta = beta_.to<scalar_t>();
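+ // Take the leading dimensions from the tensors' actual strides instead of
+ // deriving them from M/N/K, so batches whose matrices are padded or sliced
+ // views are described correctly to MKL.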
+ const int lda = is_transposed(mat1[0]) ? mat1[0].stride(1) : mat1[0].stride(0);
+ const int ldb = is_transposed(mat2[0]) ? mat2[0].stride(1) : mat2[0].stride(0);
+ const int ldc = res[0].stride(0);
+
std::vector<const scalar_t*> A(batch_size);
std::vector<const scalar_t*> B(batch_size);
std::vector<scalar_t*> C(batch_size);
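+ // cblas_*gemm_batch takes arrays of per-matrix pointers; gather one pointer
+ // per batch entry.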
for (int64_t batch = 0; batch < batch_size; batch++) {
  A[batch] = mat1[batch].data<scalar_t>();
  B[batch] = mat2[batch].data<scalar_t>();
  C[batch] = res[batch].data<scalar_t>();
}

- gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), B.data(), beta, C.data());
+ gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc);
}
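+ // In-place entry point: self <- beta * self + alpha * (batch1 @ batch2).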
Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {