1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "math_utils.hpp"
20 #include "jit_avx512_core_gemm_s8u8s32.hpp"
26 void compensation_init(const char *offsetC, int32_t *compensation, int len,
28 bool OCisC = (*offsetC == 'C' || *offsetC == 'c');
29 bool OCisF = (*offsetC == 'F' || *offsetC == 'f');
31 if (OCisF && (*oc) != 0) {
32 for (int i = 0; i < len; i++)
33 compensation[i] = *oc;
35 for (int i = 0; i < len; i++)
36 compensation[i] = oc[i];
38 parallel_nd(len, [=](int i) { compensation[i] = 0; });
42 void compensation_compute(bool transa, int m, int k, float alpha,
43 const int8_t *a, int lda, int32_t *compensation) {
45 const int L2_cache_size = get_cache_size(2, true);
46 const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
47 const int npanels = k / blocking_factor;
48 const bool has_tile = k % blocking_factor > 0;
50 parallel_nd(npanels, m, [&](int j, int i) {
52 for (int jb = 0; jb < blocking_factor; jb++) {
53 val += a[(i + (ptrdiff_t)j * blocking_factor * lda)
54 + (ptrdiff_t)jb * lda];
57 val = math::out_round<int32_t>(math::saturate<int32_t>(
58 (double)val * alpha * -128.0));
62 mkldnn_fetch_and_add(&compensation[i], val);
66 parallel_nd(m, [=](int i) {
68 for (int j = npanels * blocking_factor; j < k; j++) {
69 val += a[i + (ptrdiff_t)j * lda];
72 val = math::out_round<int32_t>(math::saturate<int32_t>(
73 (double)val * alpha * -128.0));
77 mkldnn_fetch_and_add(&compensation[i], val);
81 parallel_nd(m, [=](int i) {
83 for (int j = 0; j < k; j++) {
84 val += a[j + (ptrdiff_t)i * lda];
87 val = math::out_round<int32_t>(math::saturate<int32_t>(
88 (double)val * alpha * -128.0));
92 compensation[i] += val;
97 void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8,
98 const int8_t *b_s8, int ldb_s8) {
99 const int b_cols = transb ? k : n;
101 parallel_nd(b_cols, [=](int j) {
102 const int b_rows = transb ? n : k;
104 uint8_t *pb_u8 = b_u8 + j * ldb_u8;
105 const int8_t *pb_s8 = b_s8 + j * ldb_s8;
107 for (int i = 0; i < b_rows; i++) {
108 (*pb_u8) = (*pb_s8) + 128;
115 mkldnn_status_t jit_avx512_core_gemm_s8s8s32(
116 const char *transA, const char *transB, const char *offsetC,
117 const int *m, const int *n, const int *k,
118 const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
119 const int8_t *b, const int *ldb, const int8_t *ob,
120 const float *beta, int32_t *c, const int *ldc, const int32_t *oc) {
121 if (*oa != 0 || *ob != 0) return mkldnn_unimplemented;
123 int M = *m, N = *n, K = *k;
124 bool transa = (*transA == 'T' || *transA == 't');
125 bool transb = (*transB == 'T' || *transB == 't');
126 int ld = transb ? N : K;
128 uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64);
129 int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64);
131 if (utils::any_null(b_u8, compensation)) {
134 return mkldnn_out_of_memory;
137 compensation_init(offsetC, compensation, M, oc);
138 compensation_compute(transa, M, K, *alpha, a, *lda, compensation);
139 copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb);
141 mkldnn_gemm_s8u8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8,
142 &ld, ob, beta, c, ldc, compensation);
144 if ((*offsetC == 'R' || *offsetC == 'r'))
146 [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; });
151 return mkldnn_success;