Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_gemm_s8s8s32.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "common.hpp"
18 #include "nstl.hpp"
19 #include "math_utils.hpp"
20 #include "jit_avx512_core_gemm_s8u8s32.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25
26 void compensation_init(const char *offsetC, int32_t *compensation, int len,
27         const int32_t *oc) {
28     bool OCisC = (*offsetC == 'C' || *offsetC == 'c');
29     bool OCisF = (*offsetC == 'F' || *offsetC == 'f');
30
31    if (OCisF && (*oc) != 0) {
32        for (int i = 0; i < len; i++)
33            compensation[i] = *oc;
34    } else if (OCisC) {
35        for (int i = 0; i < len; i++)
36            compensation[i] = oc[i];
37    } else {
38        parallel_nd(len, [=](int i) { compensation[i] = 0; });
39    }
40 }
41
42 void compensation_compute(bool transa, int m, int k, float alpha,
43         const int8_t *a, int lda, int32_t *compensation) {
44     if (!transa) {
45         const int L2_cache_size = get_cache_size(2, true);
46         const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
47         const int npanels = k / blocking_factor;
48         const bool has_tile = k % blocking_factor > 0;
49
50         parallel_nd(npanels, m, [&](int j, int i) {
51             int32_t val = 0;
52             for (int jb = 0; jb < blocking_factor; jb++) {
53                 val += a[(i + (ptrdiff_t)j * blocking_factor * lda)
54                     + (ptrdiff_t)jb * lda];
55             }
56             if (alpha != 1.0f) {
57                 val = math::out_round<int32_t>(math::saturate<int32_t>(
58                     (double)val * alpha * -128.0));
59             } else {
60                 val *= -128;
61             }
62             mkldnn_fetch_and_add(&compensation[i], val);
63         });
64
65         if (has_tile) {
66             parallel_nd(m, [=](int i) {
67                 int32_t val = 0;
68                 for (int j = npanels * blocking_factor; j < k; j++) {
69                     val += a[i + (ptrdiff_t)j * lda];
70                 }
71                 if (alpha != 1.0f) {
72                     val = math::out_round<int32_t>(math::saturate<int32_t>(
73                         (double)val * alpha * -128.0));
74                 } else {
75                     val *= -128;
76                 }
77                 mkldnn_fetch_and_add(&compensation[i], val);
78             });
79         }
80     } else {
81         parallel_nd(m, [=](int i) {
82             int32_t val = 0;
83             for (int j = 0; j < k; j++) {
84                 val += a[j + (ptrdiff_t)i * lda];
85             }
86             if (alpha != 1.0f) {
87                 val = math::out_round<int32_t>(math::saturate<int32_t>(
88                     (double)val * alpha * -128.0));
89             } else {
90                 val *= -128;
91             }
92             compensation[i] += val;
93         });
94     }
95 }
96
97 void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8,
98         const int8_t *b_s8, int ldb_s8) {
99     const int b_cols = transb ? k : n;
100
101     parallel_nd(b_cols, [=](int j) {
102         const int b_rows = transb ? n : k;
103
104         uint8_t *pb_u8 = b_u8 + j * ldb_u8;
105         const int8_t *pb_s8 = b_s8 + j * ldb_s8;
106
107         for (int i = 0; i < b_rows; i++) {
108             (*pb_u8) = (*pb_s8) + 128;
109             pb_u8++;
110             pb_s8++;
111         }
112     });
113 }
114
115 mkldnn_status_t jit_avx512_core_gemm_s8s8s32(
116         const char *transA, const char *transB, const char *offsetC,
117         const int *m, const int *n, const int *k,
118         const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
119         const int8_t *b, const int *ldb, const int8_t *ob,
120         const float *beta, int32_t *c, const int *ldc, const int32_t *oc) {
121     if (*oa != 0 || *ob != 0) return mkldnn_unimplemented;
122
123     int M = *m, N = *n, K = *k;
124     bool transa = (*transA == 'T' || *transA == 't');
125     bool transb = (*transB == 'T' || *transB == 't');
126     int ld = transb ? N : K;
127
128     uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64);
129     int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64);
130
131     if (utils::any_null(b_u8, compensation)) {
132         free(b_u8);
133         free(compensation);
134         return mkldnn_out_of_memory;
135     }
136
137     compensation_init(offsetC, compensation, M, oc);
138     compensation_compute(transa, M, K, *alpha, a, *lda, compensation);
139     copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb);
140
141     mkldnn_gemm_s8u8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8,
142         &ld, ob, beta, c, ldc, compensation);
143
144     if ((*offsetC == 'R' || *offsetC == 'r'))
145         parallel_nd(M, N,
146             [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; });
147
148     free(b_u8);
149     free(compensation);
150
151     return mkldnn_success;
152 }
153 }
154 }
155 }