Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / gemm.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn.h"
18
19 #include "mkldnn_traits.hpp"
20 #include "nstl.hpp"
21
22 #include "jit_generator.hpp"
23
24 #include "gemm.hpp"
25
26 #include "f32/jit_avx512_common_gemm_f32.hpp"
27 #include "f32/jit_avx_gemm_f32.hpp"
28 #include "f32/ref_gemm_f32.hpp"
29
30 #include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp"
31 #include "s8x8s32/jit_avx512_core_gemm_s8s8s32.hpp"
32 #include "s8x8s32/ref_gemm_s8x8s32.hpp"
33
34 #include "os_blas.hpp"
35
36 /* USE_MKL      USE_CBLAS       effect
37  * -------      ---------       ------
38  * yes          yes             use Intel(R) MKL CBLAS
39  * yes          no              use jit
40  * no           yes             system-dependent CBLAS
41  * no           no              use jit
42  */
43
44 namespace mkldnn {
45 namespace impl {
46 namespace cpu {
47
48 mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
49         const int *M, const int *N, const int *K, const int *lda,
50         const int *ldb, const int *ldc, const float *alpha, const float *beta,
51         const bool with_bias) {
52     if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta))
53         return mkldnn_invalid_arguments;
54     if (with_bias && *beta != 0)
55         return mkldnn_unimplemented;
56     bool consistency = true
57         && utils::one_of(*transa, 'T', 't', 'N', 'n')
58         && utils::one_of(*transb, 'T', 't', 'N', 'n')
59         && *M >= 0
60         && *N >= 0
61         && *K >= 0;
62
63     if (!consistency)
64         return mkldnn_invalid_arguments;
65     bool isTransA = utils::one_of(*transa, 'T', 't');
66     bool isTransB = utils::one_of(*transb, 'T', 't');
67     int nrowA = isTransA ? *K : *M;
68     int nrowB = isTransB ? *N : *K;
69     consistency = true
70         && *lda >= nstl::max(1, nrowA)
71         && *ldb >= nstl::max(1, nrowB)
72         && *ldc >= nstl::max(1, *M);
73     if (!consistency)
74         return mkldnn_invalid_arguments;
75
76     return mkldnn_success;
77 }
78
79 mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
80         const char *transa, const char *transb, const int *M, const int *N,
81         const int *K, const int *lda, const int *ldb, const int *ldc,
82         const float *alpha, const float *beta, const bool with_bias) {
83     if (offsetc == nullptr)
84         return mkldnn_invalid_arguments;
85     if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
86         return mkldnn_invalid_arguments;
87
88     return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
89         beta, with_bias);
90 }
91
92 mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
93         const int *M, const int *N, const int *K, const float *alpha,
94         const float *A, const int *lda, const float *B, const int *ldb,
95         const float *beta, float *C, const int *ldc,
96         const float *bias, const bool force_jit_gemm) {
97     mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
98             lda, ldb, ldc, alpha, beta, bias != nullptr);
99     if (status != mkldnn_success)
100         return status;
101
102 #ifdef USE_CBLAS
103     if (!force_jit_gemm) {
104         bool trA = *transa == 't' || *transa == 'T';
105         bool trB = *transb == 't' || *transb == 'T';
106         CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
107         CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
108         cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
109                 *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
110
111         if (bias) {
112             // Add bias if necessary (bias is applied to columns of C)
113             cblas_int incx = 1, incy = 1;
114             parallel_nd(*N, [&](int n) {
115                 ptrdiff_t offset = (ptrdiff_t)n * (*ldc);
116                 cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
117             });
118         }
119         return mkldnn_success;
120     }
121 #endif
122
123     if (mayiuse(avx512_common))
124         return jit_avx512_common_gemm_f32(transa, transb,
125                 M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
126     else if (mayiuse(avx))
127         return jit_avx_gemm_f32(transa, transb,
128                 M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
129     else
130         return ref_gemm<float>(transa, transb,
131                 M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
132 }
133
134 template <typename b_dt>
135 mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
136         const char *offsetc, const int *M, const int *N, const int *K,
137         const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
138         const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
139         int32_t *C, const int *LDC, const int32_t *co) {
140     mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
141         M, N, K, LDA, LDB, LDC, alpha, beta, false);
142     if (status != mkldnn_success)
143         return status;
144
145     if (*M == 0 || *N == 0 || *K == 0)
146         return mkldnn_success;
147
148 #if USE_MKL_IGEMM
149         bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
150         bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
151         bool AisN = (*transa == 'N' || *transa == 'n');
152         bool BisN = (*transb == 'N' || *transb == 'n');
153
154     if (data_traits<b_dt>::data_type == data_type::u8) {
155         CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
156         CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
157         CBLAS_OFFSET Cblas_offsetc =
158             OCisR
159             ? CblasRowOffset
160             : OCisC
161             ? CblasColOffset
162             : CblasFixOffset;
163         cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
164                 *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo,
165                 *beta, C, *LDC, co);
166         return mkldnn_success;
167     } else {
168         assert(data_traits<b_dt>::data_type == data_type::s8);
169         // TODO CBLAS implementation of gemm_s8s8s32 goes here.
170         // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
171         if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
172                 && *ao == 0 && *bo == 0) {
173             return jit_avx512_core_gemm_s8s8s32(transa, transb, offsetc, M,
174                     N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
175                     C, LDC, co);
176         } else {
177             return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
178                     alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
179         }
180     }
181 #else
182     cpu_isa_t isa = isa_any;
183     if (mayiuse(avx512_core_vnni)) {
184         isa = avx512_core_vnni;
185     } else if (mayiuse(avx512_core)) {
186         isa = avx512_core;
187     }
188
189     if (data_traits<b_dt>::data_type == data_type::u8) {
190         switch (isa) {
191         case avx512_core:
192         case avx512_core_vnni:
193             return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M,
194                     N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta,
195                     C, LDC, co);
196         default:
197             return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
198                     alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
199         }
200     } else {
201         assert(data_traits<b_dt>::data_type == data_type::s8);
202         // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
203         if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
204                 && *ao == 0 && *bo == 0) {
205             return jit_avx512_core_gemm_s8s8s32(transa, transb, offsetc, M,
206                     N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
207                     C, LDC, co);
208         } else {
209             return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
210                     alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
211         }
212     }
213 #endif
214 }
215
216 }
217 }
218 }
219
220 using namespace mkldnn::impl;
221 using namespace mkldnn::impl::cpu;
222
223 mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb,
224         const int *M, const int *N, const int *K, const float *alpha,
225         const float *A, const int *lda, const float *B, const int *ldb,
226         const float *beta, float *C, const int *ldc) {
227     return extended_sgemm(
228             transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
229 }
230
231 mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
232         const char *offsetc, const int *M, const int *N, const int *K,
233         const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
234         const uint8_t *B, const int *ldb, const int8_t *bo, const float *beta,
235         int32_t *C, const int *ldc, const int32_t *co) {
236     return gemm_s8x8s32(
237         transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
238         beta, C, ldc, co);
239 }
240
241 mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
242         const char *offsetc, const int *M, const int *N, const int *K,
243         const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
244         const int8_t *B, const int *ldb, const int8_t *bo, const float *beta,
245         int32_t *C, const int *ldc, const int32_t *co) {
246     return gemm_s8x8s32(
247         transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
248         beta, C, ldc, co);
249 }