Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / ref_gemm.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_thread.hpp"
18 #include "nstl.hpp"
19 #include "utils.hpp"
20
21 #include "../jit_generator.hpp"
22
23 #include "gemm_utils.hpp"
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using namespace mkldnn::impl::utils;
30 using namespace gemm_utils;
31
32
33 template <typename data_t>
34 static void copy_A(
35         bool isTransA, int K, const data_t *A, const int lda, data_t *ws) {
36     for (int k = 0; k < K; k++) {
37         PRAGMA_OMP_SIMD()
38         for (int i = 0; i < gemm_utils::unroll_factor<data_t>::m; i++) {
39             ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
40         }
41         ws += unroll_factor<data_t>::m;
42     }
43 }
44
45 template <typename data_t, bool isTransA, bool isTransB>
46 static void kernel_mxn(int K, const data_t *A, const int lda,
47         const data_t *B, const int ldb, data_t *C, const int ldc,
48         const data_t alpha, const data_t beta) {
49     data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
50         { static_cast<data_t>(0.) };
51     for (int k = 0; k < K; k++) {
52         for (int j = 0; j < unroll_factor<data_t>::n; j++) {
53             data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
54             PRAGMA_OMP_SIMD()
55             for (int i = 0; i < unroll_factor<data_t>::m; i++) {
56                 data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
57                 c[i + unroll_factor<data_t>::m * j] += a * b;
58             }
59         }
60     }
61     for (int j = 0; j < unroll_factor<data_t>::n; j++) {
62         PRAGMA_OMP_SIMD()
63         for (int i = 0; i < unroll_factor<data_t>::m; i++) {
64             C[i + j * ldc] = (beta == static_cast<data_t>(0.))
65             ? alpha * c[i + unroll_factor<data_t>::m * j]
66             : alpha * c[i + unroll_factor<data_t>::m * j]
67                 + beta * C[i + j * ldc];
68         }
69     }
70 }
71
72 template <typename data_t, bool isTransA, bool isTransB>
73 static void block_ker(const int M, const int N, const int K,
74         const data_t *A, const int lda, const data_t *B, const int ldb,
75         data_t *C, const int ldc, const data_t alpha, const data_t beta,
76         data_t *ws, bool do_copy) {
77     int Nu = rnd_dn(N, unroll_factor<data_t>::n);
78     int Mu = rnd_dn(M, unroll_factor<data_t>::m);
79     for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) {
80         for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) {
81             const data_t *b = isTransB ? &B[j] : &B[j * ldb];
82             const data_t *a = isTransA ? &A[i * lda] : &A[i];
83             if (do_copy) {
84                 if (j == 0) {
85                     copy_A<data_t>(isTransA, K, a, lda, ws);
86                 }
87                 kernel_mxn<data_t, false, isTransB>(
88                         K, ws, unroll_factor<data_t>::m, b, ldb,
89                         &C[i + j * ldc], ldc, alpha, beta);
90             } else {
91                 kernel_mxn<data_t, isTransA, isTransB>(
92                         K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
93             }
94         }
95     }
96     // tail processing
97     for (int i = 0; i < M; i++) {
98         for (int j = Nu; j < N; j++) {
99             data_t c = beta == static_cast<data_t>(0.)
100                 ? static_cast<data_t>(0.)
101                 : beta * C[i + j * ldc];
102             for (int p = 0; p < K; p++) {
103                 data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
104                 data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
105                 c += alpha * a * b;
106             }
107             C[i + j * ldc] = c;
108         }
109     }
110     for (int i = Mu; i < M; i++) {
111         for (int j = 0; j < Nu; j++) {
112             data_t c = beta == static_cast<data_t>(0.)
113                 ? static_cast<data_t>(0.)
114                 : beta * C[i + j * ldc];
115             for (int p = 0; p < K; p++) {
116                 data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
117                 data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
118                 c += alpha * a * b;
119             }
120             C[i + j * ldc] = c;
121         }
122     }
123 }
124
125 template <typename data_t, bool isTransA, bool isTransB>
126 void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
127         const data_t *A, const int lda, const data_t *B, const int ldb,
128         const data_t beta, data_t *C, const int ldc, bool do_copy, data_t *ws) {
129     constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
130     constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
131     constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
132
133     const data_t *curA;
134     const data_t *curB;
135     data_t *curC;
136
137     if ((M <= 0) || (N <= 0))
138         return;
139
140     if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
141         ptrdiff_t MN = (ptrdiff_t)N * M;
142         if (beta == static_cast<data_t>(0.)) {
143             for (ptrdiff_t j = 0; j < MN; j++)
144                 C[j] = static_cast<data_t>(0.);
145         } else if (beta != static_cast<data_t>(1.)) {
146             for (ptrdiff_t j = 0; j < MN; j++)
147                 C[j] *= beta;
148         }
149         return;
150     }
151
152     for (int Bk = 0; Bk < K; Bk += BK) {
153         int kb = nstl::min(K - Bk, BK);
154         for (int Bm = 0; Bm < M; Bm += BM) {
155             int mb = nstl::min(M - Bm, BM);
156             for (int Bn = 0; Bn < N; Bn += BN) {
157                 int nb = nstl::min(N - Bn, BN);
158                 curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
159                 curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
160                 curC = C + Bm + Bn * ldc;
161                 if (Bk == 0) {
162                     block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
163                         curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
164                 } else {
165                     block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
166                         curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0),
167                         ws, do_copy);
168                 }
169             }
170         }
171     }
172 }
173
174 template <typename data_t>
175 void ref_gemm(const char *transa_, const char *transb_, const int *M_,
176         const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
177         const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
178         data_t *C, const int *ldc_, const data_t *bias) {
179     bool isTransA = (*transa_ == 'T' || *transa_ == 't');
180     bool isTransB = (*transb_ == 'T' || *transb_ == 't');
181     const int M = *M_, N = *N_, K = *K_, lda = *lda_, ldb = *ldb_, ldc = *ldc_;
182     const data_t alpha = *alpha_, beta = *beta_;
183
184     int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
185     int nthr_m, nthr_n, nthr_k;
186     int MB, NB, KB;
187     // thread balancing over M, N, K & size of blocking dimensions
188     gemm_utils::calc_nthr_nocopy_avx(
189             M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
190     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
191
192     data_t *c_buffers = nullptr;
193     data_t *ws_buffers = nullptr;
194     if (nthr_k > 1) {
195         c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
196                 * sizeof(data_t), PAGE_4K);
197         if (!c_buffers) {
198             nthr_k = 1;
199             KB = K;
200         }
201     }
202
203     bool do_copy = (NB / unroll_factor<data_t>::n > 3);
204     const int nthr_mn = nthr_m * nthr_n;
205     const int nthr = nthr_mn * nthr_k;
206     const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
207     const size_t ws_size_per_thr
208             = utils::rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
209     if (do_copy) {
210         ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
211         if (!ws_buffers)
212             do_copy = false;
213     }
214
215     parallel(nthr, [&](const int ithr, const int nthr) {
216         int ithr_mn = ithr % nthr_mn;
217         int ithr_m = ithr_mn % nthr_m;
218         int ithr_n = ithr_mn / nthr_m;
219         int ithr_k = ithr / nthr_mn;
220
221         int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
222
223         data_t *ws = do_copy
224                 ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
225                 : nullptr;
226
227         int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
228                 k_from = 0, k_to = 0, myK = 0;
229         auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
230                 int ithr) {
231             from = NB * (ithr);
232             to = NB * (ithr + 1);
233             if (to > N)
234                 to = N;
235             myN = to - from;
236         };
237         get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
238         get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
239         get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
240
241         if (myM > 0 && myN > 0) {
242             data_t myBeta, *myC;
243             int ld;
244             if (ithr_k == 0) {
245                 myC = &(C[m_from + n_from * ldc]);
246                 myBeta = beta;
247                 ld = ldc;
248             } else {
249                 myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
250                 myBeta = 0.0f;
251                 ld = MB;
252             }
253             const data_t *myA = isTransA
254                     ? &(A[k_from + m_from * lda])
255                     : &(A[m_from + k_from * lda]);
256             const data_t *myB = isTransB
257                     ? &(B[n_from + k_from * ldb])
258                     : &(B[k_from + n_from * ldb]);
259
260             if (!isTransA) {
261                 if (!isTransB) {
262                     gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
263                         lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
264                 } else {
265                     gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
266                         lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
267                 }
268             } else {
269                 if (!isTransB) {
270                     gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
271                         lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
272                 } else {
273                     gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
274                         lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
275                 }
276             }
277         }
278
279         if (nthr_k > 1) {
280             assert(mkldnn_thr_syncable());
281             mkldnn_thr_barrier();
282
283             // sum matrices partitioned along K dimension
284             int offset = 0, block = 0;
285             gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
286                     &block);
287             for (int ik = 1; ik < nthr_k; ++ik) {
288                 data_t *myC = c_buffers + MB * (NB * (cbase + ik - 1) + offset);
289
290                 gemm_utils::sum_two_matrices(myM, block, myC, MB,
291                         &C[m_from + (n_from + offset) * ldc], ldc);
292             }
293         }
294     });
295
296     if (bias) {
297         parallel_nd(N, M, [&](int i, int j) {
298             C[i*ldc + j] += bias[j];
299         });
300     }
301
302     free(ws_buffers);
303     free(c_buffers);
304 }
305
306 template void ref_gemm<float>(const char *transa_, const char *transb_,
307         const int *M_, const int *N_, const int *K_, const float *alpha_,
308         const float *A, const int *lda_, const float *B, const int *ldb_,
309         const float *beta_, float *C, const int *ldc_, const float *bias);
310
311 template void ref_gemm<double>(const char *transa_, const char *transb_,
312         const int *M_, const int *N_, const int *K_, const double *alpha_,
313         const double *A, const int *lda_, const double *B, const int *ldb_,
314         const double *beta_, double *C, const int *ldc_, const double *bias);
315 }
316 }
317 }