Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_gemm_common.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef TEST_GEMM_COMMON_H
18 #define TEST_GEMM_COMMON_H
19
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22
23 #include "mkldnn_types.h"
24 #include "mkldnn.h"
25
26 #define CONCAT_WITH_UNDERSCORE_(a,b) a ## _ ## b
27 #define CONCAT_WITH_UNDERSCORE(a,b) CONCAT_WITH_UNDERSCORE_(a,b)
28
29 #define INST_TEST_CASE_(str, ...) INSTANTIATE_TEST_CASE_P( \
30         str, gemm_test, ::testing::Values(__VA_ARGS__))
31 #define INST_TEST_CASE(str, ...) INST_TEST_CASE_( \
32         CONCAT_WITH_UNDERSCORE(str,TEST_CASE_NAME_PREFIX), __VA_ARGS__)
33
34 namespace mkldnn {
35
36 struct test_igemm_params {
37     char offsetc;
38     bool zero_oa;
39     bool zero_ob;
40     bool zero_oc;
41 };
42
43 struct test_params {
44     char transA;
45     char transB;
46     int M;
47     int N;
48     int K;
49     float alpha;
50     float beta;
51     int lda;
52     int ldb;
53     int ldc;
54
55     test_igemm_params igemm_params;
56     bool expect_to_fail;
57     mkldnn_status_t expected_status;
58 };
59
60 template <typename data_t>
61 void ref_gemm(const char *transa, const char *transb, int m, int n, int k,
62         const data_t alpha, const data_t *a, int lda, const data_t *b,
63         int ldb, data_t beta, data_t *c, int ldc) {
64
65     const bool tr_a = transa && (*transa == 'T' || *transa == 't');
66     const bool tr_b = transb && (*transb == 'T' || *transb == 't');
67
68     auto pa = [=] (int i, int j) { return a[j*lda + i]; };
69     auto pb = [=] (int i, int j) { return b[j*ldb + i]; };
70     auto pc = [=] (int i, int j) { return c[j*ldc + i]; };
71
72     mkldnn::impl::parallel_nd(m, n, [&](int im, int in) {
73         data_t c_elem = (beta == 0.) ? 0. : pc(im, in) * beta;
74
75         for (int ik = 0; ik < k; ik++) {
76             const data_t a_elem = tr_a ? pa(ik, im) : pa(im, ik);
77             const data_t b_elem = tr_b ? pb(in, ik) : pb(ik, in);
78             c_elem += alpha * a_elem * b_elem;
79         }
80         c[in*ldc + im] = c_elem;
81     });
82 }
83
84 template <typename b_dt>
85 void ref_gemm_s8x8s32(const char *transa, const char *transb,
86         const char *offsetc, int m, int n, int k, const float alpha,
87         int8_t *A, int lda, const int8_t *oa, b_dt *B, int ldb,
88         const int8_t *ob, const float beta, int32_t *C, int ldc,
89         const int32_t *oc) {
90
91     bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
92     bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
93     bool AisN = (*transa == 'N' || *transa == 'n');
94     bool BisN = (*transb == 'N' || *transb == 'n');
95
96     size_t sizeA = AisN ? lda * k : lda * m;
97     size_t sizeB = BisN ? ldb * n : ldb * k;
98     size_t sizeC = ldc * n;
99
100     double *dA = (double *)test_malloc(sizeA * sizeof(double));
101     double *dB = (double *)test_malloc(sizeB * sizeof(double));
102     double *dC = (double *)test_malloc(sizeC * sizeof(double));
103
104     auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
105     auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
106
107     auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
108     auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
109
110     const int a_rows = AisN ? m : k;
111     const int a_cols = AisN ? k : m;
112     mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
113         da_setter(i, j,
114             static_cast<double>(ia_accessor(i, j)) + static_cast<double>(oa[0]));
115     });
116
117     const int b_rows = BisN ? k : n;
118     const int b_cols = BisN ? n : k;
119     mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
120         db_setter(i, j,
121             static_cast<double>(ib_accessor(i, j)) + static_cast<double>(ob[0]));
122     });
123
124     ref_gemm(transa, transb, m, n, k, 1.0, dA, lda, dB, ldb, 0.0, dC, ldc);
125
126     auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
127     auto f2d = [=] (float v) { return static_cast<double>(v); };
128
129     mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
130         double coffset = OCisR ? i2d(oc[j]) : OCisC ? i2d(oc[i]) : i2d(oc[0]);
131         double val = ((beta == 0.0f) ? 0.0 : f2d(beta) * i2d(C[i + j * ldc]))
132             + f2d(alpha) * dC[i + j * ldc] + coffset;
133         C[i + j * ldc] =
134             static_cast<int32_t>(nearbyint(saturate<int32_t, double>(val)));
135     });
136
137     test_free((char *)dA);
138     test_free((char *)dB);
139     test_free((char *)dC);
140 }
141
142 template <typename b_dt, typename c_dt>
143 void compare(int m, int n, const c_dt *c, const c_dt *c_ref, int ldc,
144         float alpha = 1.0f, float beta = 0.0f, int k = 1) {
145     using data_type = memory::data_type;
146     mkldnn::impl::parallel_nd(n, ldc, [&](int i, int j) {
147         c_dt ref = c_ref[i*ldc + j];
148         c_dt got = c[i*ldc + j];
149         c_dt diff = got - ref;
150
151         if (data_traits<b_dt>::data_type == data_type::f32) {
152             c_dt e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
153             EXPECT_NEAR(e, 0.0, 1e-4) << "Row: " << j << " Col: " << i;
154         } else {
155             // igemm
156             if (alpha == 1.0f) {
157                 EXPECT_NEAR(diff, 0, 1) << "Row: " << j << " Col: " << i;
158             } else {
159                 if (data_traits<b_dt>::data_type == data_type::u8) {
160                     c_dt eps = k / 1000 + 1;
161                     EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
162                 } else if (data_traits<b_dt>::data_type == data_type::s8) {
163                     c_dt eps = k / 500 + 1;
164                     EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
165                 }
166             }
167         }
168     });
169 }
170
171 inline void get_matrix_size(const test_params &p, size_t &sizeA,
172         size_t &sizeB, size_t &sizeC) {
173     const bool tr_a = (p.transA == 'T' || p.transA == 't');
174     const bool tr_b = (p.transB == 'T' || p.transB == 't');
175     sizeA = !tr_a ? p.lda * p.K : p.lda * p.M,
176     sizeB = !tr_b ? p.ldb * p.N : p.ldb * p.K,
177     sizeC = p.ldc * p.N;
178 }
179
180 template <typename T>
181 inline T* get_matrix_buffer(size_t n) {
182     return (T*)test_malloc(n * sizeof(T));
183 }
184
185 template <typename a_dt, typename b_dt, typename c_dt>
186 inline void fill_matrix(const test_params &p, size_t sizeA, size_t sizeB,
187         size_t sizeC, size_t sizeco, a_dt *A, b_dt *B, c_dt *C, a_dt *oa,
188         a_dt *ob, c_dt *oc) {
189     fill_data<a_dt>(sizeA, A);
190     fill_data<b_dt>(sizeB, B);
191     fill_data<c_dt>(sizeC, C);
192     if (oa != nullptr && ob != nullptr && oc != nullptr) {
193         if (p.igemm_params.zero_oa) (*oa) = 0;
194         else fill_data<a_dt>(1, oa);
195
196         if (p.igemm_params.zero_ob) (*ob) = 0;
197         else fill_data<a_dt>(1, ob);
198
199         if (p.igemm_params.zero_oc) {
200             for (size_t i = 0; i < sizeco; i++)
201                 oc[i] = 0;
202         } else fill_data<c_dt>(sizeco, oc);
203     }
204 }
205
206 template <typename a_dt, typename b_dt, typename c_dt>
207 void run_test_gemm(const test_params &p) {}
208
209 template <>
210 void run_test_gemm<int8_t, uint8_t, int32_t>(const test_params &p) {
211     size_t sizeA, sizeB, sizeC;
212     get_matrix_size(p, sizeA, sizeB, sizeC);
213
214     int8_t  *A = get_matrix_buffer<int8_t>(sizeA);
215     uint8_t *B = get_matrix_buffer<uint8_t>(sizeB);
216     int32_t *C = get_matrix_buffer<int32_t>(sizeC);
217     int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
218
219     bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
220     bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
221     size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
222
223     int8_t oa, ob;
224     int32_t *oc = get_matrix_buffer<int32_t>(sizeco);
225
226     fill_matrix<int8_t, uint8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco,
227         A, B, C, &oa, &ob, oc);
228
229     mkldnn::impl::parallel_nd(p.ldc * p.N,
230         [&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
231
232     auto status = mkldnn_gemm_s8u8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
233         &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
234         &p.beta, C, &p.ldc, oc);
235
236     if (status != mkldnn_success)
237         throw error(status, "mkldnn_gemm_s8u8s32 returned error");
238
239     ref_gemm_s8x8s32<uint8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
240         p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
241         p.ldc, oc);
242
243     compare<uint8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
244
245     test_free((char *)A);
246     test_free((char *)B);
247     test_free((char *)C);
248     test_free((char *)C_ref);
249     test_free((char *)oc);
250 }
251
252 template <>
253 void run_test_gemm<int8_t, int8_t, int32_t>(const test_params &p) {
254     size_t sizeA, sizeB, sizeC;
255     get_matrix_size(p, sizeA, sizeB, sizeC);
256
257     int8_t  *A = get_matrix_buffer<int8_t>(sizeA);
258     int8_t  *B = get_matrix_buffer<int8_t>(sizeB);
259     int32_t *C = get_matrix_buffer<int32_t>(sizeC);
260     int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
261
262     bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
263     bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
264     size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
265
266     int8_t oa, ob;
267     int32_t* oc = get_matrix_buffer<int32_t>(sizeco);
268
269     fill_matrix<int8_t, int8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco, A, B, C,
270         &oa, &ob, oc);
271
272     mkldnn::impl::parallel_nd(p.ldc * p.N,
273         [&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
274
275     auto status = mkldnn_gemm_s8s8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
276         &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
277         &p.beta, C, &p.ldc, oc);
278
279     if (status != mkldnn_success)
280         throw error(status, "mkldnn_gemm_s8s8s32 returned error");
281
282     ref_gemm_s8x8s32<int8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
283         p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
284         p.ldc, oc);
285
286     compare<int8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
287
288     test_free((char *)A);
289     test_free((char *)B);
290     test_free((char *)C);
291     test_free((char *)C_ref);
292     test_free((char *)oc);
293 }
294
295 template <>
296 void run_test_gemm<float, float, float>(const test_params &p) {
297     size_t sizeA, sizeB, sizeC;
298     get_matrix_size(p, sizeA, sizeB, sizeC);
299
300     float *A = get_matrix_buffer<float>(sizeA);
301     float *B = get_matrix_buffer<float>(sizeB);
302     float *C = get_matrix_buffer<float>(sizeC);
303     float *C_ref = get_matrix_buffer<float>(sizeC);
304
305     fill_matrix<float, float, float>(p, sizeA, sizeB, sizeC, 0, A, B, C,
306         nullptr, nullptr, nullptr);
307
308     mkldnn::impl::parallel_nd(p.N * p.ldc, [&](int i) { C_ref[i] = C[i]; });
309
310     auto status = mkldnn_sgemm(&p.transA, &p.transB, &p.M, &p.N, &p.K, &p.alpha,
311         A, &p.lda, B, &p.ldb, &p.beta, C, &p.ldc);
312     if (status == mkldnn_success) {
313         ref_gemm(&p.transA, &p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
314             p.beta, C_ref, p.ldc);
315         compare<float, float>(p.M, p.N, C, C_ref, p.ldc);
316     }
317
318     test_free((char *)A);
319     test_free((char *)B);
320     test_free((char *)C);
321     test_free((char *)C_ref);
322
323     if (status != mkldnn_success)
324         throw error(status, "mkldnn_sgemm returned error");
325 }
326
327 template <typename a_dt, typename b_dt, typename c_dt>
328 class gemm_test_common: public ::testing::TestWithParam<test_params> {
329 protected:
330     virtual void SetUp() {
331         test_params p
332             = ::testing::TestWithParam<test_params>::GetParam();
333         catch_expected_failures([=](){Test();}, p.expect_to_fail,
334                     p.expected_status);
335     }
336     virtual void Test() {
337         test_params p = ::testing::TestWithParam<test_params>::GetParam();
338         run_test_gemm<a_dt, b_dt, c_dt>(p);
339     }
340 };
341 }
342 #endif