1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "mkldnn_traits.hpp"
22 #include "jit_generator.hpp"
26 #include "f32/jit_avx512_common_gemm_f32.hpp"
27 #include "f32/jit_avx_gemm_f32.hpp"
28 #include "f32/ref_gemm_f32.hpp"
30 #include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp"
31 #include "s8x8s32/jit_avx512_core_gemm_s8s8s32.hpp"
32 #include "s8x8s32/ref_gemm_s8x8s32.hpp"
34 #include "os_blas.hpp"
36 /* USE_MKL USE_CBLAS effect
37 * ------- --------- ------
38 * yes yes use Intel(R) MKL CBLAS
40 * no yes system-dependent CBLAS
// Validates the scalar arguments shared by the GEMM entry points in this file.
// Every argument except with_bias is passed by pointer and is dereferenced
// here only after the null check has passed.
//
// Parameters:
//   transa, transb  one-character transpose selectors for A and B
//   M, N, K         problem sizes
//   lda, ldb, ldc   leading dimensions of A, B and C
//   alpha, beta     scaling factors (only *beta is inspected here)
//   with_bias       true when the caller will also apply a bias
//
// Returns:
//   mkldnn_invalid_arguments  if any listed pointer is null or a consistency
//                             check fails
//   mkldnn_unimplemented      if with_bias is set and *beta is non-zero
//   mkldnn_success            otherwise
//
// NOTE(review): this listing omits some source lines (the embedded numbering
// jumps 58 to 64, 68 to 70 and 72 to 74). The omitted lines presumably finish
// the two consistency expressions and hold the `if` tests that guard the
// invalid-argument returns -- confirm against the full file.
48 mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
49 const int *M, const int *N, const int *K, const int *lda,
50 const int *ldb, const int *ldc, const float *alpha, const float *beta,
51 const bool with_bias) {
// Reject null pointers first: everything below dereferences these arguments.
52 if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta))
53 return mkldnn_invalid_arguments;
// A bias combined with a non-zero beta is reported as unimplemented.
54 if (with_bias && *beta != 0)
55 return mkldnn_unimplemented;
// Each transpose selector must be one of 'T', 't', 'N', 'n'.
56 bool consistency = true
57 && utils::one_of(*transa, 'T', 't', 'N', 'n')
58 && utils::one_of(*transb, 'T', 't', 'N', 'n')
64 return mkldnn_invalid_arguments;
// 'T' or 't' marks an operand as transposed.
65 bool isTransA = utils::one_of(*transa, 'T', 't');
66 bool isTransB = utils::one_of(*transb, 'T', 't');
// Row count of each operand as stored: A has K rows when transposed and M
// rows otherwise; B has N rows when transposed and K rows otherwise.
67 int nrowA = isTransA ? *K : *M;
68 int nrowB = isTransB ? *N : *K;
// Each leading dimension must be at least 1 and at least the stored row
// count of its matrix; for C that row count is *M.
70 && *lda >= nstl::max(1, nrowA)
71 && *ldb >= nstl::max(1, nrowB)
72 && *ldc >= nstl::max(1, *M);
74 return mkldnn_invalid_arguments;
76 return mkldnn_success;
// Argument check for the integer GEMM entry points: validates offsetc and
// then delegates the remaining checks to check_gemm_input.
//
// Returns mkldnn_invalid_arguments when offsetc is null or *offsetc is not
// one of 'F', 'f', 'C', 'c', 'R', 'r'; otherwise returns whatever
// check_gemm_input returns.
//
// NOTE(review): the tail of the delegating call is on a line this listing
// does not show; it presumably passes beta and with_bias -- confirm.
79 mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
80 const char *transa, const char *transb, const int *M, const int *N,
81 const int *K, const int *lda, const int *ldb, const int *ldc,
82 const float *alpha, const float *beta, const bool with_bias) {
// offsetc is dereferenced on the next test, so the null check comes first.
83 if (offsetc == nullptr)
84 return mkldnn_invalid_arguments;
// Only these six selector characters are accepted.
85 if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
86 return mkldnn_invalid_arguments;
// The remaining arguments get the same validation as the f32 GEMM.
88 return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
// Single-precision GEMM with an optional bias and a switch that forces the
// non-CBLAS implementations.
//
// The arguments are validated with check_gemm_input, with with_bias set to
// (bias != nullptr). Two ways of doing the work are visible:
//   1. When force_jit_gemm is false, cblas_sgemm is called in column-major
//      mode, the bias is added with cblas_saxpy, and mkldnn_success is
//      returned.
//   2. Otherwise the call is dispatched on CPU capability:
//      mayiuse(avx512_common) selects jit_avx512_common_gemm_f32,
//      mayiuse(avx) selects jit_avx_gemm_f32, and ref_gemm<float> is the
//      remaining case. Each callee's result is returned unchanged.
//
// NOTE(review): this listing omits some source lines (the embedded numbering
// jumps, for example 99 to 103, 109 to 112, 116 to 119 and 119 to 123).
// Presumably the omitted lines hold the early `return status;`, a
// build-time guard around the CBLAS path, the test on bias, and the closing
// braces -- confirm against the full file.
92 mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
93 const int *M, const int *N, const int *K, const float *alpha,
94 const float *A, const int *lda, const float *B, const int *ldb,
95 const float *beta, float *C, const int *ldc,
96 const float *bias, const bool force_jit_gemm) {
// A non-null bias turns on the bias-specific check in check_gemm_input.
97 mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
98 lda, ldb, ldc, alpha, beta, bias != nullptr);
99 if (status != mkldnn_success)
// CBLAS path, used unless the caller forced the other implementations.
103 if (!force_jit_gemm) {
// Translate the character selectors into CBLAS transpose enumerators.
104 bool trA = *transa == 't' || *transa == 'T';
105 bool trB = *transb == 't' || *transb == 'T';
106 CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
107 CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
// cblas_sgemm takes sizes and scalars by value, hence the dereferences.
108 cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
109 *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
112 // Add bias if necessary (bias is applied to columns of C)
113 cblas_int incx = 1, incy = 1;
// One task per column n of C: the first *M entries of bias are added, with
// unit stride, to the *M entries of C that start at offset n * ldc.
114 parallel_nd(*N, [&](int n) {
// The cast makes the product be computed in ptrdiff_t rather than int.
115 ptrdiff_t offset = (ptrdiff_t)n * (*ldc);
116 cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
119 return mkldnn_success;
// Capability dispatch; all three callees receive the same argument list,
// bias included.
123 if (mayiuse(avx512_common))
124 return jit_avx512_common_gemm_f32(transa, transb,
125 M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
126 else if (mayiuse(avx))
127 return jit_avx_gemm_f32(transa, transb,
128 M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
130 return ref_gemm<float>(transa, transb,
131 M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
// Integer GEMM with int8_t A, B of element type b_dt and int32_t C. The
// code branches on whether data_traits<b_dt>::data_type is data_type::u8
// and asserts data_type::s8 in the other branch.
//
// Flow visible in this listing:
//   - the arguments are validated by check_gemm_x8x8x32_input with
//     with_bias fixed to false;
//   - if any of *M, *N, *K is zero the function returns mkldnn_success
//     without calling any implementation;
//   - two dispatch variants follow. The first uses cblas_gemm_s8u8s32 for
//     u8 B. The second picks a cpu_isa_t and uses
//     jit_avx512_core_gemm_s8u8s32 for u8 B. For s8 B both variants use
//     jit_avx512_core_gemm_s8s8s32 when mayiuse(avx512_core) or
//     mayiuse(avx512_core_vnni) holds and both *ao and *bo are zero.
//     ref_gemm_s8x8s32 is the other callee visible in each branch.
//
// NOTE(review): this listing omits many source lines (the embedded numbering
// jumps, for example 142 to 145, 157 to 163, 178 to 182 and 189 to 192).
// Presumably the two variants are selected by a build-time conditional and
// the omitted lines also hold `return status;`, the initializer of
// Cblas_offsetc, the switch over isa, the call tails and the closing
// braces -- confirm against the full file.
// NOTE(review): *ao and *bo are dereferenced below, but neither visible
// check function tests ao or bo for null -- confirm callers never pass null.
134 template <typename b_dt>
135 mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
136 const char *offsetc, const int *M, const int *N, const int *K,
137 const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
138 const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
139 int32_t *C, const int *LDC, const int32_t *co) {
// No bias exists for the integer GEMM, so with_bias is passed as false.
140 mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
141 M, N, K, LDA, LDB, LDC, alpha, beta, false);
142 if (status != mkldnn_success)
// A zero-sized problem is reported as success with no further work.
145 if (*M == 0 || *N == 0 || *K == 0)
146 return mkldnn_success;
// Decode the character selectors once. Where OCisR and OCisC are consumed
// is not visible in this listing.
149 bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
150 bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
151 bool AisN = (*transa == 'N' || *transa == 'n');
152 bool BisN = (*transb == 'N' || *transb == 'n');
// First variant, u8 B: call the CBLAS integer GEMM in column-major mode.
154 if (data_traits<b_dt>::data_type == data_type::u8) {
155 CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
156 CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
157 CBLAS_OFFSET Cblas_offsetc =
// Sizes, alpha and the A and B offsets are passed by value here.
163 cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
164 *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo,
166 return mkldnn_success;
// First variant, s8 B.
168 assert(data_traits<b_dt>::data_type == data_type::s8);
169 // TODO CBLAS implementation of gemm_s8s8s32 goes here.
170 // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
171 if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
172 && *ao == 0 && *bo == 0) {
173 return jit_avx512_core_gemm_s8s8s32(transa, transb, offsetc, M,
174 N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
177 return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
178 alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
// Second variant: start from isa_any and test avx512_core_vnni before
// avx512_core.
182 cpu_isa_t isa = isa_any;
183 if (mayiuse(avx512_core_vnni)) {
184 isa = avx512_core_vnni;
185 } else if (mayiuse(avx512_core)) {
// Second variant, u8 B.
189 if (data_traits<b_dt>::data_type == data_type::u8) {
192 case avx512_core_vnni:
193 return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M,
194 N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta,
197 return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
198 alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
// Second variant, s8 B: same condition and callees as the s8 branch above.
201 assert(data_traits<b_dt>::data_type == data_type::s8);
202 // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
203 if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
204 && *ao == 0 && *bo == 0) {
205 return jit_avx512_core_gemm_s8s8s32(transa, transb, offsetc, M,
206 N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
209 return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
210 alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
220 using namespace mkldnn::impl;
221 using namespace mkldnn::impl::cpu;
// Single-precision GEMM entry point: forwards all thirteen arguments, in
// order, to extended_sgemm and returns its status unchanged.
//
// NOTE(review): extended_sgemm as defined in this file has two further
// parameters (bias, force_jit_gemm) that are not passed here. Presumably a
// declaration in a header supplies default values for them -- confirm.
223 mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb,
224 const int *M, const int *N, const int *K, const float *alpha,
225 const float *A, const int *lda, const float *B, const int *ldb,
226 const float *beta, float *C, const int *ldc) {
227 return extended_sgemm(
228 transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
// Integer GEMM entry point with int8_t A, uint8_t B and int32_t C.
// The visible body line passes the parameters through in declaration order.
//
// NOTE(review): the line naming the callee and the tail of the argument
// list are not shown in this listing (the embedded numbering jumps 235 to
// 237). The callee is presumably gemm_s8x8s32 instantiated for uint8_t --
// confirm against the full file.
231 mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
232 const char *offsetc, const int *M, const int *N, const int *K,
233 const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
234 const uint8_t *B, const int *ldb, const int8_t *bo, const float *beta,
235 int32_t *C, const int *ldc, const int32_t *co) {
237 transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,
// Integer GEMM entry point with int8_t A, int8_t B and int32_t C.
// The visible body line passes the parameters through in declaration order.
//
// NOTE(review): the line naming the callee and the tail of the argument
// list are not shown in this listing (the embedded numbering jumps 245 to
// 247 and the body continues past the last visible line). The callee is
// presumably gemm_s8x8s32 instantiated for int8_t -- confirm against the
// full file.
241 mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
242 const char *offsetc, const int *M, const int *N, const int *K,
243 const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
244 const int8_t *B, const int *ldb, const int8_t *bo, const float *beta,
245 int32_t *C, const int *ldc, const int32_t *co) {
247 transa, transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo,