1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef TEST_GEMM_COMMON_H
18 #define TEST_GEMM_COMMON_H
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
23 #include "mkldnn_types.h"
// Token-pasting helpers. The extra level of indirection makes sure that
// macro arguments are expanded *before* they are glued together with '_'.
#define CONCAT_WITH_UNDERSCORE_(a,b) a ## _ ## b
#define CONCAT_WITH_UNDERSCORE(a,b) CONCAT_WITH_UNDERSCORE_(a,b)

// Instantiates the value-parameterized `gemm_test` fixture with the given
// parameter sets. The instantiation name is `str` followed by an underscore
// and the expansion of TEST_CASE_NAME_PREFIX. Neither `gemm_test` nor
// TEST_CASE_NAME_PREFIX is defined in this header; both are expected to be
// provided by the including test source.
#define INST_TEST_CASE_(str, ...) INSTANTIATE_TEST_CASE_P( \
        str, gemm_test, ::testing::Values(__VA_ARGS__))
#define INST_TEST_CASE(str, ...) INST_TEST_CASE_( \
        CONCAT_WITH_UNDERSCORE(str,TEST_CASE_NAME_PREFIX), __VA_ARGS__)
36 struct test_igemm_params {
55 test_igemm_params igemm_params;
57 mkldnn_status_t expected_status;
60 template <typename data_t>
61 void ref_gemm(const char *transa, const char *transb, int m, int n, int k,
62 const data_t alpha, const data_t *a, int lda, const data_t *b,
63 int ldb, data_t beta, data_t *c, int ldc) {
65 const bool tr_a = transa && (*transa == 'T' || *transa == 't');
66 const bool tr_b = transb && (*transb == 'T' || *transb == 't');
68 auto pa = [=] (int i, int j) { return a[j*lda + i]; };
69 auto pb = [=] (int i, int j) { return b[j*ldb + i]; };
70 auto pc = [=] (int i, int j) { return c[j*ldc + i]; };
72 mkldnn::impl::parallel_nd(m, n, [&](int im, int in) {
73 data_t c_elem = (beta == 0.) ? 0. : pc(im, in) * beta;
75 for (int ik = 0; ik < k; ik++) {
76 const data_t a_elem = tr_a ? pa(ik, im) : pa(im, ik);
77 const data_t b_elem = tr_b ? pb(in, ik) : pb(ik, in);
78 c_elem += alpha * a_elem * b_elem;
80 c[in*ldc + im] = c_elem;
84 template <typename b_dt>
85 void ref_gemm_s8x8s32(const char *transa, const char *transb,
86 const char *offsetc, int m, int n, int k, const float alpha,
87 int8_t *A, int lda, const int8_t *oa, b_dt *B, int ldb,
88 const int8_t *ob, const float beta, int32_t *C, int ldc,
91 bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
92 bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
93 bool AisN = (*transa == 'N' || *transa == 'n');
94 bool BisN = (*transb == 'N' || *transb == 'n');
96 size_t sizeA = AisN ? lda * k : lda * m;
97 size_t sizeB = BisN ? ldb * n : ldb * k;
98 size_t sizeC = ldc * n;
100 double *dA = (double *)test_malloc(sizeA * sizeof(double));
101 double *dB = (double *)test_malloc(sizeB * sizeof(double));
102 double *dC = (double *)test_malloc(sizeC * sizeof(double));
104 auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
105 auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
107 auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
108 auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
110 const int a_rows = AisN ? m : k;
111 const int a_cols = AisN ? k : m;
112 mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
114 static_cast<double>(ia_accessor(i, j)) + static_cast<double>(oa[0]));
117 const int b_rows = BisN ? k : n;
118 const int b_cols = BisN ? n : k;
119 mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
121 static_cast<double>(ib_accessor(i, j)) + static_cast<double>(ob[0]));
124 ref_gemm(transa, transb, m, n, k, 1.0, dA, lda, dB, ldb, 0.0, dC, ldc);
126 auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
127 auto f2d = [=] (float v) { return static_cast<double>(v); };
129 mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
130 double coffset = OCisR ? i2d(oc[j]) : OCisC ? i2d(oc[i]) : i2d(oc[0]);
131 double val = ((beta == 0.0f) ? 0.0 : f2d(beta) * i2d(C[i + j * ldc]))
132 + f2d(alpha) * dC[i + j * ldc] + coffset;
134 static_cast<int32_t>(nearbyint(saturate<int32_t, double>(val)));
137 test_free((char *)dA);
138 test_free((char *)dB);
139 test_free((char *)dC);
142 template <typename b_dt, typename c_dt>
143 void compare(int m, int n, const c_dt *c, const c_dt *c_ref, int ldc,
144 float alpha = 1.0f, float beta = 0.0f, int k = 1) {
145 using data_type = memory::data_type;
146 mkldnn::impl::parallel_nd(n, ldc, [&](int i, int j) {
147 c_dt ref = c_ref[i*ldc + j];
148 c_dt got = c[i*ldc + j];
149 c_dt diff = got - ref;
151 if (data_traits<b_dt>::data_type == data_type::f32) {
152 c_dt e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
153 EXPECT_NEAR(e, 0.0, 1e-4) << "Row: " << j << " Col: " << i;
157 EXPECT_NEAR(diff, 0, 1) << "Row: " << j << " Col: " << i;
159 if (data_traits<b_dt>::data_type == data_type::u8) {
160 c_dt eps = k / 1000 + 1;
161 EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
162 } else if (data_traits<b_dt>::data_type == data_type::s8) {
163 c_dt eps = k / 500 + 1;
164 EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
171 inline void get_matrix_size(const test_params &p, size_t &sizeA,
172 size_t &sizeB, size_t &sizeC) {
173 const bool tr_a = (p.transA == 'T' || p.transA == 't');
174 const bool tr_b = (p.transB == 'T' || p.transB == 't');
175 sizeA = !tr_a ? p.lda * p.K : p.lda * p.M,
176 sizeB = !tr_b ? p.ldb * p.N : p.ldb * p.K,
180 template <typename T>
181 inline T* get_matrix_buffer(size_t n) {
182 return (T*)test_malloc(n * sizeof(T));
185 template <typename a_dt, typename b_dt, typename c_dt>
186 inline void fill_matrix(const test_params &p, size_t sizeA, size_t sizeB,
187 size_t sizeC, size_t sizeco, a_dt *A, b_dt *B, c_dt *C, a_dt *oa,
188 a_dt *ob, c_dt *oc) {
189 fill_data<a_dt>(sizeA, A);
190 fill_data<b_dt>(sizeB, B);
191 fill_data<c_dt>(sizeC, C);
192 if (oa != nullptr && ob != nullptr && oc != nullptr) {
193 if (p.igemm_params.zero_oa) (*oa) = 0;
194 else fill_data<a_dt>(1, oa);
196 if (p.igemm_params.zero_ob) (*ob) = 0;
197 else fill_data<a_dt>(1, ob);
199 if (p.igemm_params.zero_oc) {
200 for (size_t i = 0; i < sizeco; i++)
202 } else fill_data<c_dt>(sizeco, oc);
206 template <typename a_dt, typename b_dt, typename c_dt>
207 void run_test_gemm(const test_params &p) {}
210 void run_test_gemm<int8_t, uint8_t, int32_t>(const test_params &p) {
211 size_t sizeA, sizeB, sizeC;
212 get_matrix_size(p, sizeA, sizeB, sizeC);
214 int8_t *A = get_matrix_buffer<int8_t>(sizeA);
215 uint8_t *B = get_matrix_buffer<uint8_t>(sizeB);
216 int32_t *C = get_matrix_buffer<int32_t>(sizeC);
217 int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
219 bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
220 bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
221 size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
224 int32_t *oc = get_matrix_buffer<int32_t>(sizeco);
226 fill_matrix<int8_t, uint8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco,
227 A, B, C, &oa, &ob, oc);
229 mkldnn::impl::parallel_nd(p.ldc * p.N,
230 [&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
232 auto status = mkldnn_gemm_s8u8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
233 &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
234 &p.beta, C, &p.ldc, oc);
236 if (status != mkldnn_success)
237 throw error(status, "mkldnn_gemm_s8u8s32 returned error");
239 ref_gemm_s8x8s32<uint8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
240 p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
243 compare<uint8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
245 test_free((char *)A);
246 test_free((char *)B);
247 test_free((char *)C);
248 test_free((char *)C_ref);
249 test_free((char *)oc);
253 void run_test_gemm<int8_t, int8_t, int32_t>(const test_params &p) {
254 size_t sizeA, sizeB, sizeC;
255 get_matrix_size(p, sizeA, sizeB, sizeC);
257 int8_t *A = get_matrix_buffer<int8_t>(sizeA);
258 int8_t *B = get_matrix_buffer<int8_t>(sizeB);
259 int32_t *C = get_matrix_buffer<int32_t>(sizeC);
260 int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
262 bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
263 bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
264 size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
267 int32_t* oc = get_matrix_buffer<int32_t>(sizeco);
269 fill_matrix<int8_t, int8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco, A, B, C,
272 mkldnn::impl::parallel_nd(p.ldc * p.N,
273 [&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
275 auto status = mkldnn_gemm_s8s8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
276 &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
277 &p.beta, C, &p.ldc, oc);
279 if (status != mkldnn_success)
280 throw error(status, "mkldnn_gemm_s8s8s32 returned error");
282 ref_gemm_s8x8s32<int8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
283 p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
286 compare<int8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
288 test_free((char *)A);
289 test_free((char *)B);
290 test_free((char *)C);
291 test_free((char *)C_ref);
292 test_free((char *)oc);
296 void run_test_gemm<float, float, float>(const test_params &p) {
297 size_t sizeA, sizeB, sizeC;
298 get_matrix_size(p, sizeA, sizeB, sizeC);
300 float *A = get_matrix_buffer<float>(sizeA);
301 float *B = get_matrix_buffer<float>(sizeB);
302 float *C = get_matrix_buffer<float>(sizeC);
303 float *C_ref = get_matrix_buffer<float>(sizeC);
305 fill_matrix<float, float, float>(p, sizeA, sizeB, sizeC, 0, A, B, C,
306 nullptr, nullptr, nullptr);
308 mkldnn::impl::parallel_nd(p.N * p.ldc, [&](int i) { C_ref[i] = C[i]; });
310 auto status = mkldnn_sgemm(&p.transA, &p.transB, &p.M, &p.N, &p.K, &p.alpha,
311 A, &p.lda, B, &p.ldb, &p.beta, C, &p.ldc);
312 if (status == mkldnn_success) {
313 ref_gemm(&p.transA, &p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
314 p.beta, C_ref, p.ldc);
315 compare<float, float>(p.M, p.N, C, C_ref, p.ldc);
318 test_free((char *)A);
319 test_free((char *)B);
320 test_free((char *)C);
321 test_free((char *)C_ref);
323 if (status != mkldnn_success)
324 throw error(status, "mkldnn_sgemm returned error");
327 template <typename a_dt, typename b_dt, typename c_dt>
328 class gemm_test_common: public ::testing::TestWithParam<test_params> {
330 virtual void SetUp() {
332 = ::testing::TestWithParam<test_params>::GetParam();
333 catch_expected_failures([=](){Test();}, p.expect_to_fail,
336 virtual void Test() {
337 test_params p = ::testing::TestWithParam<test_params>::GetParam();
338 run_test_gemm<a_dt, b_dt, c_dt>(p);