1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_thread.hpp"
21 #include "../jit_generator.hpp"
23 #include "gemm_utils.hpp"
29 using namespace mkldnn::impl::utils;
30 using namespace gemm_utils;
// Packs an unroll_factor<data_t>::m x K panel of A into the contiguous
// per-thread scratch buffer `ws`, resolving the transpose flag once here so
// the inner kernel can always read A with unit stride.
// NOTE(review): this excerpt is truncated -- the function-name line between
// the template header and the parameter list, and the closing braces of the
// inner loop / outer loop / function, are not visible here.
33 template <typename data_t>
35 bool isTransA, int K, const data_t *A, const int lda, data_t *ws) {
36 for (int k = 0; k < K; k++) {
// Copy one column (or row, if transposed) of the m-tall panel.
38 for (int i = 0; i < gemm_utils::unroll_factor<data_t>::m; i++) {
39 ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
// Advance ws by one packed column; on return ws points past the panel.
41 ws += unroll_factor<data_t>::m;
// Register-blocked micro-kernel: computes an m x n tile
// (m = unroll_factor<data_t>::m, n = unroll_factor<data_t>::n) of
// C := alpha * op(A) * op(B) + beta * C, accumulating into the local
// array `c` and writing back once at the end. Matrices are addressed
// column-major via their leading dimensions (lda/ldb/ldc).
// NOTE(review): closing braces of the loops/function are not visible in
// this truncated excerpt.
45 template <typename data_t, bool isTransA, bool isTransB>
46 static void kernel_mxn(int K, const data_t *A, const int lda,
47 const data_t *B, const int ldb, data_t *C, const int ldc,
48 const data_t alpha, const data_t beta) {
// Zero-initialized accumulator tile kept in registers/stack.
49 data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
50 { static_cast<data_t>(0.) };
51 for (int k = 0; k < K; k++) {
52 for (int j = 0; j < unroll_factor<data_t>::n; j++) {
53 data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
55 for (int i = 0; i < unroll_factor<data_t>::m; i++) {
56 data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
57 c[i + unroll_factor<data_t>::m * j] += a * b;
// Write-back: when beta == 0 the old contents of C are not read at all
// (so C may be uninitialized memory), matching BLAS beta semantics.
61 for (int j = 0; j < unroll_factor<data_t>::n; j++) {
63 for (int i = 0; i < unroll_factor<data_t>::m; i++) {
64 C[i + j * ldc] = (beta == static_cast<data_t>(0.))
65 ? alpha * c[i + unroll_factor<data_t>::m * j]
66 : alpha * c[i + unroll_factor<data_t>::m * j]
67 + beta * C[i + j * ldc];
// Computes an M x N x K block of C = alpha*op(A)*op(B) + beta*C by tiling
// it into full unroll_factor tiles handled by kernel_mxn, plus two scalar
// edge loops for the remainders in N and M. When `do_copy` is set, the A
// panel is first packed into `ws` so kernel_mxn can run with isTransA=false
// and a unit-stride lda.
// NOTE(review): this excerpt is truncated -- the `if (do_copy)`/`else`
// structure around the two kernel_mxn calls, the innermost accumulation
// statements and final stores of both edge loops, and several closing
// braces are not visible here.
72 template <typename data_t, bool isTransA, bool isTransB>
73 static void block_ker(const int M, const int N, const int K,
74 const data_t *A, const int lda, const data_t *B, const int ldb,
75 data_t *C, const int ldc, const data_t alpha, const data_t beta,
76 data_t *ws, bool do_copy) {
// Nu/Mu: portions of N and M divisible by the unroll factors.
77 int Nu = rnd_dn(N, unroll_factor<data_t>::n);
78 int Mu = rnd_dn(M, unroll_factor<data_t>::m);
79 for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) {
80 for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) {
81 const data_t *b = isTransB ? &B[j] : &B[j * ldb];
82 const data_t *a = isTransA ? &A[i * lda] : &A[i];
// Packed path: A panel copied to ws, then kernel runs non-transposed
// with leading dimension unroll_factor<data_t>::m.
85 copy_A<data_t>(isTransA, K, a, lda, ws);
87 kernel_mxn<data_t, false, isTransB>(
88 K, ws, unroll_factor<data_t>::m, b, ldb,
89 &C[i + j * ldc], ldc, alpha, beta);
// Direct (non-copy) path: kernel reads A in place.
91 kernel_mxn<data_t, isTransA, isTransB>(
92 K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
// Edge case: remaining N columns (j in [Nu, N)) for all rows, scalar.
// beta == 0 again avoids reading C.
97 for (int i = 0; i < M; i++) {
98 for (int j = Nu; j < N; j++) {
99 data_t c = beta == static_cast<data_t>(0.)
100 ? static_cast<data_t>(0.)
101 : beta * C[i + j * ldc];
102 for (int p = 0; p < K; p++) {
103 data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
104 data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
// Edge case: remaining M rows (i in [Mu, M)) for the full-tile columns.
110 for (int i = Mu; i < M; i++) {
111 for (int j = 0; j < Nu; j++) {
112 data_t c = beta == static_cast<data_t>(0.)
113 ? static_cast<data_t>(0.)
114 : beta * C[i + j * ldc];
115 for (int p = 0; p < K; p++) {
116 data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
117 data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
// Single-thread GEMM over one thread's sub-problem: cache-blocks the
// M x N x K product with compile-time block sizes (BM/BN/BK from
// gemm_traits) and dispatches each block to block_ker.
// Quick-return semantics follow BLAS: empty M/N does nothing; K <= 0 or
// alpha == 0 reduces to scaling C by beta.
// NOTE(review): this excerpt is truncated -- the `return` statements of the
// quick-return paths, the beta != 1 scaling statement, the declarations of
// curA/curB/curC, the `if (Bk == 0)`/`else` split around the two block_ker
// calls (first K-block uses the caller's beta, later K-blocks accumulate
// with beta = 1), and closing braces are not visible here.
125 template <typename data_t, bool isTransA, bool isTransB>
126 void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
127 const data_t *A, const int lda, const data_t *B, const int ldb,
128 const data_t beta, data_t *C, const int ldc, bool do_copy, data_t *ws) {
129 constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
130 constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
131 constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
// Degenerate output: nothing to compute.
137 if ((M <= 0) || (N <= 0))
// No accumulation possible: C := beta * C (with beta == 0 meaning zero-fill,
// never reading potentially-uninitialized C).
140 if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
141 ptrdiff_t MN = (ptrdiff_t)N * M;
142 if (beta == static_cast<data_t>(0.)) {
143 for (ptrdiff_t j = 0; j < MN; j++)
144 C[j] = static_cast<data_t>(0.);
145 } else if (beta != static_cast<data_t>(1.)) {
146 for (ptrdiff_t j = 0; j < MN; j++)
// Blocked triple loop; K is outermost so each K-block accumulates into the
// same C blocks.
152 for (int Bk = 0; Bk < K; Bk += BK) {
153 int kb = nstl::min(K - Bk, BK);
154 for (int Bm = 0; Bm < M; Bm += BM) {
155 int mb = nstl::min(M - Bm, BM);
156 for (int Bn = 0; Bn < N; Bn += BN) {
157 int nb = nstl::min(N - Bn, BN);
// Column-major addressing of the current block corners.
158 curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
159 curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
160 curC = C + Bm + Bn * ldc;
// First K-block applies the caller's beta ...
162 block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
163 curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
// ... subsequent K-blocks accumulate (beta = 1).
165 block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
166 curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0),
// Reference GEMM entry point with a Fortran/BLAS-style interface
// (all scalars passed by pointer, column-major layout):
//   C := alpha * op(A) * op(B) + beta * C, optionally followed by adding
//   `bias` per output row when bias != nullptr (see the parallel_nd at the
//   end).
// Parallelization: the problem is partitioned across threads in all three
// dimensions (nthr_m x nthr_n x nthr_k). Threads with ithr_k == 0 write
// directly into C; threads with ithr_k > 0 write partial products into
// per-thread `c_buffers` slabs, which are then reduced into C after a
// barrier. `ws_buffers` provides per-thread packing scratch for copy_A
// (sized and aligned to PAGE_4K).
// NOTE(review): this excerpt is truncated -- allocation-failure fallbacks,
// parts of get_thr_block, the myBeta/ld selection, the transA/transB
// dispatch conditions, the free() calls and function tail are not visible
// here; comments below describe only what the visible lines show.
174 template <typename data_t>
175 void ref_gemm(const char *transa_, const char *transb_, const int *M_,
176 const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
177 const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
178 data_t *C, const int *ldc_, const data_t *bias) {
// 'T'/'t' means the operand is transposed, BLAS-style.
179 bool isTransA = (*transa_ == 'T' || *transa_ == 't');
180 bool isTransB = (*transb_ == 'T' || *transb_ == 't');
181 const int M = *M_, N = *N_, K = *K_, lda = *lda_, ldb = *ldb_, ldc = *ldc_;
182 const data_t alpha = *alpha_, beta = *beta_;
// Avoid nested parallelism: run single-threaded if already inside a
// parallel region.
184 int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
185 int nthr_m, nthr_n, nthr_k;
187 // thread balancing over M, N, K & size of blocking dimensions
188 gemm_utils::calc_nthr_nocopy_avx(
189 M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
// K-partitioning requires a barrier; forbid it on non-syncable runtimes.
190 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
192 data_t *c_buffers = nullptr;
193 data_t *ws_buffers = nullptr;
// One MB x NB partial-C slab per (m,n) pair for each extra K-thread
// (ithr_k >= 1); the ithr_k == 0 thread writes to C directly.
195 c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
196 * sizeof(data_t), PAGE_4K);
// Packing pays off only when each thread has enough full n-tiles.
203 bool do_copy = (NB / unroll_factor<data_t>::n > 3);
204 const int nthr_mn = nthr_m * nthr_n;
205 const int nthr = nthr_mn * nthr_k;
// Per-thread packing scratch: one K x m panel, rounded up to page size.
206 const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
207 const size_t ws_size_per_thr
208 = utils::rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
210 ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
215 parallel(nthr, [&](const int ithr, const int nthr) {
// Decompose linear thread id into (m, n, k) coordinates; m is fastest.
216 int ithr_mn = ithr % nthr_mn;
217 int ithr_m = ithr_mn % nthr_m;
218 int ithr_n = ithr_mn / nthr_m;
219 int ithr_k = ithr / nthr_mn;
// Base slab index in c_buffers for this (m, n) pair.
221 int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
224 ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
227 int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
228 k_from = 0, k_to = 0, myK = 0;
// Maps a thread coordinate to its [from, to) range and its length.
229 auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
232 to = NB * (ithr + 1);
237 get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
238 get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
239 get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
241 if (myM > 0 && myN > 0) {
// ithr_k == 0 path (presumably): write straight into C at this thread's
// block origin -- TODO confirm against the elided conditional.
245 myC = &(C[m_from + n_from * ldc]);
// Extra K-threads write into their private partial-C slab instead.
249 myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
253 const data_t *myA = isTransA
254 ? &(A[k_from + m_from * lda])
255 : &(A[m_from + k_from * lda]);
256 const data_t *myB = isTransB
257 ? &(B[n_from + k_from * ldb])
258 : &(B[k_from + n_from * ldb]);
// Dispatch to the statically-specialized kernel for each transpose combo.
262 gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
263 lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
265 gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
266 lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
270 gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
271 lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
273 gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
274 lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
// Reduction phase (only reached when nthr_k > 1, per the barrier +
// assert): all threads must finish their partial products first.
280 assert(mkldnn_thr_syncable());
281 mkldnn_thr_barrier();
283 // sum matrices partitioned along K dimension
// Each thread in the k-group sums a disjoint column stripe (offset/block)
// of every other k-thread's partial slab into C.
284 int offset = 0, block = 0;
285 gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
287 for (int ik = 1; ik < nthr_k; ++ik) {
288 data_t *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset);
290 gemm_utils::sum_two_matrices(myM, block, myC, MB,
291 &C[m_from + (n_from + offset) * ldc], ldc);
// Post-op: add bias[j] to every element of row j (column-major C,
// so bias is per-row of the M x N result).
297 parallel_nd(N, M, [&](int i, int j) {
298 C[i*ldc + j] += bias[j];
// Explicit instantiations: ref_gemm is compiled here for float and double
// so other translation units can link against it without seeing the body.
306 template void ref_gemm<float>(const char *transa_, const char *transb_,
307 const int *M_, const int *N_, const int *K_, const float *alpha_,
308 const float *A, const int *lda_, const float *B, const int *ldb_,
309 const float *beta_, float *C, const int *ldc_, const float *bias);
311 template void ref_gemm<double>(const char *transa_, const char *transb_,
312 const int *M_, const int *N_, const int *K_, const double *alpha_,
313 const double *A, const int *lda_, const double *B, const int *ldb_,
314 const double *beta_, double *C, const int *ldc_, const double *bias);