* limitations under the License.
*******************************************************************************/
+#include "mkldnn_types.h"
+
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "utils.hpp"
-#include "../jit_generator.hpp"
+#include "jit_generator.hpp"
-#include "gemm_utils.hpp"
+#include "gemm_utils_f32.hpp"
+#include "ref_gemm_f32.hpp"
namespace mkldnn {
namespace impl {
using namespace mkldnn::impl::utils;
using namespace gemm_utils;
+namespace {
template <typename data_t>
-static void copy_A(
- bool isTransA, int K, const data_t *A, const int lda, data_t *ws) {
+void copy_A(
+ bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) {
for (int k = 0; k < K; k++) {
PRAGMA_OMP_SIMD()
- for (int i = 0; i < gemm_utils::unroll_factor<data_t>::m; i++) {
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
}
ws += unroll_factor<data_t>::m;
}
template <typename data_t, bool isTransA, bool isTransB>
-static void kernel_mxn(int K, const data_t *A, const int lda,
- const data_t *B, const int ldb, data_t *C, const int ldc,
+void kernel_mxn(int K, const data_t *A, const dim_t lda,
+ const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc,
const data_t alpha, const data_t beta) {
data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
{ static_cast<data_t>(0.) };
}
template <typename data_t, bool isTransA, bool isTransB>
-static void block_ker(const int M, const int N, const int K,
- const data_t *A, const int lda, const data_t *B, const int ldb,
- data_t *C, const int ldc, const data_t alpha, const data_t beta,
+void block_ker(const int M, const int N, const int K,
+ const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+ data_t *C, const dim_t ldc, const data_t alpha, const data_t beta,
data_t *ws, bool do_copy) {
int Nu = rnd_dn(N, unroll_factor<data_t>::n);
int Mu = rnd_dn(M, unroll_factor<data_t>::m);
template <typename data_t, bool isTransA, bool isTransB>
void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
- const data_t *A, const int lda, const data_t *B, const int ldb,
- const data_t beta, data_t *C, const int ldc, bool do_copy, data_t *ws) {
+ const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+ const data_t beta, data_t *C, const dim_t ldc, bool do_copy,
+ data_t *ws) {
constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
return;
if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
- ptrdiff_t MN = (ptrdiff_t)N * M;
+ dim_t MN = N * M;
if (beta == static_cast<data_t>(0.)) {
- for (ptrdiff_t j = 0; j < MN; j++)
+ for (dim_t j = 0; j < MN; j++)
C[j] = static_cast<data_t>(0.);
} else if (beta != static_cast<data_t>(1.)) {
- for (ptrdiff_t j = 0; j < MN; j++)
+ for (dim_t j = 0; j < MN; j++)
C[j] *= beta;
}
return;
}
}
+}
+
template <typename data_t>
-void ref_gemm(const char *transa_, const char *transb_, const int *M_,
+mkldnn_status_t ref_gemm(
+ const char *transa_, const char *transb_, const int *M_,
const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
data_t *C, const int *ldc_, const data_t *bias) {
+
bool isTransA = (*transa_ == 'T' || *transa_ == 't');
bool isTransB = (*transb_ == 'T' || *transb_ == 't');
- const int M = *M_, N = *N_, K = *K_, lda = *lda_, ldb = *ldb_, ldc = *ldc_;
+ const int M = *M_, N = *N_, K = *K_;
+ const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
const data_t alpha = *alpha_, beta = *beta_;
int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
int nthr_m, nthr_n, nthr_k;
int MB, NB, KB;
// thread balancing over M, N, K & size of blocking dimensions
- gemm_utils::calc_nthr_nocopy_avx(
+ calc_nthr_nocopy_avx(
M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
const int nthr = nthr_mn * nthr_k;
const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
const size_t ws_size_per_thr
- = utils::rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
+ = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
if (do_copy) {
ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
if (!ws_buffers)
do_copy = false;
}
- parallel(nthr, [&](const int ithr, const int nthr) {
+ auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
+ int ithr) {
+ from = NB * (ithr);
+ to = NB * (ithr + 1);
+ if (to > N)
+ to = N;
+ myN = to - from;
+ };
+
+ parallel_nd(nthr, [&](const int ithr) {
int ithr_mn = ithr % nthr_mn;
int ithr_m = ithr_mn % nthr_m;
int ithr_n = ithr_mn / nthr_m;
int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
k_from = 0, k_to = 0, myK = 0;
- auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
- int ithr) {
- from = NB * (ithr);
- to = NB * (ithr + 1);
- if (to > N)
- to = N;
- myN = to - from;
- };
+
get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
if (myM > 0 && myN > 0) {
data_t myBeta, *myC;
- int ld;
+ dim_t ld;
if (ithr_k == 0) {
myC = &(C[m_from + n_from * ldc]);
myBeta = beta;
ld = ldc;
} else {
- myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
myBeta = 0.0f;
ld = MB;
}
}
}
}
+ });
- if (nthr_k > 1) {
- assert(mkldnn_thr_syncable());
- mkldnn_thr_barrier();
+ if (nthr_k > 1) {
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_mn = ithr % nthr_mn;
+ int ithr_m = ithr_mn % nthr_m;
+ int ithr_k = ithr / nthr_mn;
+ int ithr_n = ithr_mn / nthr_m;
+
+ int n_from = 0, n_to = 0, myN = 0;
+ int m_from = 0, m_to = 0, myM = 0;
+
+ int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
+ get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
// sum matrices partitioned along K dimension
int offset = 0, block = 0;
gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
&block);
for (int ik = 1; ik < nthr_k; ++ik) {
- data_t *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset);
+ data_t *myC = c_buffers
+ + MB * ((dim_t)NB * (cbase + ik - 1) + offset);
gemm_utils::sum_two_matrices(myM, block, myC, MB,
&C[m_from + (n_from + offset) * ldc], ldc);
}
- }
- });
+ });
+ }
if (bias) {
parallel_nd(N, M, [&](int i, int j) {
free(ws_buffers);
free(c_buffers);
+
+ return mkldnn_success;
}
-template void ref_gemm<float>(const char *transa_, const char *transb_,
+template mkldnn_status_t ref_gemm<float>(
+ const char *transa_, const char *transb_,
const int *M_, const int *N_, const int *K_, const float *alpha_,
const float *A, const int *lda_, const float *B, const int *ldb_,
const float *beta_, float *C, const int *ldc_, const float *bias);
-template void ref_gemm<double>(const char *transa_, const char *transb_,
+template mkldnn_status_t ref_gemm<double>(
+ const char *transa_, const char *transb_,
const int *M_, const int *N_, const int *K_, const double *alpha_,
const double *A, const int *lda_, const double *B, const int *ldb_,
const double *beta_, double *C, const int *ldc_, const double *bias);